From e59641c2dbf5818fe79d732bb65130a7424b2e7e Mon Sep 17 00:00:00 2001 From: Alex Ororbia Date: Fri, 1 May 2026 13:03:12 -0400 Subject: [PATCH 01/23] Align main with release (#150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Final nudge to v3.1.0 (#149) * implemented raw classical instantaneous stdp * mod to classical stdp * mod to classical stdp syn * mod to stdp syn * mod to stdp syn * mod to stdp syn * mod to stdp syn * mod to stdp syn * minor mod of syn * cleaned up stdp syn * cleaned up stdp syn * Sync up of main with release (#131) * Minor nudge to v3.0.1 (#129) * minor edit to math in hh-lesson doc * Fix workflow, numpy install, and pytest bug in github action workflows (#117) * Update pyproject.toml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * minor nudge/cleanup to minor patched version 2.0.1 * minor nudge/cleanup to minor patched version 2.0.3 * Merged back minor doc fix back to main (for syncing purposes) (#119) * Nudge of release to minor patched version 2.0.3 (#118) * nudge of doc to 2.0.2 (#115) Co-authored-by: Alexander Ororbia * minor edit to math in hh-lesson doc * Fix workflow, numpy install, and pytest bug in github action workflows (#117) * Update pyproject.toml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * Update python-package-conda.yml * minor nudge/cleanup to minor patched version 2.0.1 * minor nudge/cleanup to minor patched version 2.0.3 --------- Co-authored-by: Alexander Ororbia Co-authored-by: Viet Dung Nguyen <60036798+rxng8@users.noreply.github.com> * fixed typo/error in doc evolving_synapses.md --------- Co-authored-by: Alexander Ororbia Co-authored-by: Viet Dung Nguyen <60036798+rxng8@users.noreply.github.com> * minor clean-up in model_basics docs * minor fixes/cleanup of docs * fixed typo in integration tutorial doc * updated papers/talk page for ngclearn * Merging over v3 to main (for roll-out of v3 upgrade) (#125) * Working v3 * Undid fixed compartemts Undid the fixed compartments to work with new global constant tracking * Fixed an execution bug * ported over quad-lif to v3 - needs testing * ported over IF/quadLIF cells, minor revision to LIF cell * Start util cleanup * refactored/ported RAFCell to v3 * ported over/refactored WTASCell for v3 * wrote successful unit-test of WTASCell * put back in init-structure/pointers * fixed minor error in LIFCell, got unit-test for LIFCell to run * quad-lif test sketched * sketch of ifcell test * fixed minor bugs and tests locally pass for if, quad-lif, and lif-cells now, with minor patches to help fun and doc-strings * refactored raf-cell and test passed * refactored adex/test passed; minor cleanup in lif, raf, and wtas cells * refactored fn-cell and test passed * cleaned up lif, raf, wtas, fn, and quad-lif cells repr method * refactored and tests passed for izh and h-h cells * JaxProcess update * cleaned up dunder repr method, moved to JaxComponent parent; fixed __init__ pointer to tensorstats * refactored alpha and exp-synapses, tests passed; minor edit to __init__ for synapses * refactored short-term syn, tests passed - including stp-dense-syn and minor cleanup/edit to synapse __init__ * refactored bcm-syn and test passed * refactored exp-stdp-syn and passed tests for exp-stdp-syn and trace-stdp-syn * refactored event-stdp-syn and test passed * refactored mstdpet-syn and test passed * refactored stdp-conv-syn/conv-syn and test passed * refactored and passed test for deconv/stdp-deconv-syn and other minor cleanup for conv/deconv support * Refactoring neuronal and synaptic components (#123) - merge from fork to v3 * refactoring graded cells * update refactored models * update sLIF cell --------- Co-authored-by: Alex Ororbia * commented out deprecator in hebb-syn and exp-kernel * update hebbian synapse * update hebbian synapse * working reinforce synapse * minor edits to exp-kernel/wtas-cell * update requirements * refactored conv/deconv-hebb-syn and tests passed * update hebbian synapse reset bug * update reset methods * update patched synapse reset * add `not self.inputs.targeted and ` to required components. Fixing general `__repr__` bug in `jaxcomponent` * minor edit to lif/modulated-syn init file * fixed some minor bugs in rate-coded cells/hebb-syn * update code * minor patches to components, including hebb-syn/conv/deconv and reward-cell * minor patches to components, including hebb-syn/conv/deconv and reward-cell * update testing for graded neurons and input encoders * update phasor cell * update test bernoulli cell and poisson cell * update components and their related test cases * fixed monitor bugs from v2, tweaked unit-tests for input-encoders/latency-cell * update test case for test_sLIFCell.py * some cleanup * made revisions to components/clean-up; added back in deprecators * removed lava sub-module, and removed monitor/base-monitor legacy components * minor cleanup of inits * refactored regression module to be compliant with v3 * adjusted sphinx-docs w.r.t. new v3 refactoring * minor revision to double-exp syn pointing, mods to modeling docs * updated adex tutorial doc to v3 * revised adex and error-cell neurocog tutorials * fixed minor issues in input-encoders, further revisions to docs for v3 * revised dyn/chem-syn neurocog doc, cleaned up dynamic syn * revised fn and hh-cell neurocog docs, added some refs to distribution generator * revised integration and izh-cell neurocog docs * revised izh-cell, cleaned-up fn-cell, and revised lif neurocog docs * revised metrics/plotting neurocog docs * revised mod/reward-stdp neurocog doc * revised stp-syn neurocog doc and updated stp-syn to use proper initializer * revised elements of utils to comply with docs * revised stdp neurocog doc to v3 * revised traces neurocog tutorial to v3 * cleaned up utils.optim and wrote compliant NAG optim * cleaned up utils.optim and wrote compliant NAG optim * cleanup of components, added leaky-noise-cell, minor edits * revised leaky-noise-cell, wrote its unit test, test-passed * some revisions/updates to toc/pointer/general tutorial docs * minor revisions to pyproject/req files * update reinforce synapse * update test cases * implemented in-house gmm, in-built to ngclearn; tested on gaussian mode data * wrote gmm density estimator tutorial * patched some tests/syn/neuron components, added sketch of bmm density * fixed test_laplacianErrorCell and laplace-cell bug * fixed test_laplacianErrorCell and laplace-cell bug * made patches to bmm * updated density tutorial/neurocog doc * minor edit to gmm/bmm docs * minor edit to gmm/bmm docs * cleaned up density structure, use parent mixture class to organize model variations * cleaned up density structure, use parent mixture class to organize model variations * added basic exp-mixture to utils.density * minor edits to emm * cleaned up mixtures and finished debugging EMM/works on example * removed old weight_distribution.py, other cleanup/revisions throughout * minor edit to data-loader * revised tests to no longer use weight_distribution/revisions throughout * minor edit to emm doc * added bic calculation to metric_utils * fix ratecell ug of passing unrelated kwargs to parent class * added calc_update() co-routine to hebbian-syn component * fix weight init * integrated rbm/harmonium model-exhibit * Update __init__.py Added the config/logging back to the init * placed pointer to rao-ballard1999 exhibit; updates to docs * updates to docs/revisions * removed flag from bernoulli/latency-cells for now; minor edit to doc * updates to theory doc * updated history log * minor clean-up of ngclearn.utils.viz.dim_reduce * Update jaxComponent.py Added support for turning off autosave * update hebbian synapse saving * update saving and loading utils, making hebbian synapse use these utils for custom optimizer params saving and loading * minor revisions/polish * modded docs to include v3 foundations * updates to init for logging * Updates to lessons * final cleanup/polish/update to docs for v3 nudge * updates to museum doc for v3 * nudged citation file * minor nudge to docs/files to point to v3 --------- Co-authored-by: Will Gebhardt Co-authored-by: Alexander Ororbia Co-authored-by: Viet Dung Nguyen <60036798+rxng8@users.noreply.github.com> Co-authored-by: Viet Nguyen Co-authored-by: Viet Dung Nguyen * update to rbm/harmonium doc * updated leaky-noise-cell to maintain temporal derivative of state * minor revisons/updates to hebb/dense syn, metric utils * cleaned-up/revised leaky-noise-cell * cleaned-up/revised leaky-noise-cell --------- Co-authored-by: Alexander Ororbia Co-authored-by: Viet Dung Nguyen <60036798+rxng8@users.noreply.github.com> Co-authored-by: Will Gebhardt Co-authored-by: Viet Nguyen Co-authored-by: Viet Dung Nguyen * nudge release to v3.0.1 * minor revision of leaky-noise-cell --------- Co-authored-by: Alexander Ororbia Co-authored-by: Viet Dung Nguyen <60036798+rxng8@users.noreply.github.com> Co-authored-by: Will Gebhardt Co-authored-by: Viet Nguyen Co-authored-by: Viet Dung Nguyen * added pointer/stub for ei-rnn song-et-al in museum doc * update to ei-rnn doc * update to ei-rnn arch fig * added log-gaussian initializer to distribution_generator * bug-fix to log-gaussian func * Refactor patch utility functions and add doc strings (#136) * Rao1999 hpc (#135) * Enhance documentation for predictive coding model Expanded the documentation for the predictive coding model, detailing the construction of neural and synaptic components, process dynamics, and training procedures. * Add files via upload * Update image path for GEC in documentation * Revise PC model documentation formatting Updated headings and formatting for clarity in PC model documentation. * Change header levels for PC model training sections * Add files via upload * Update PC model training section with input image details Added explanation about the input image for the PC model training. * Update pc_rao_ballard1999.md * Delete docs/images/museum/hgpc/Patch_input.png * Add files via upload * Fix image source and enhance PC model description Updated image source and adjusted description for clarity. * Refactor PC model training sections in documentation Removed the section for training the PC model on the full image and added a reference to it in the patched image section. * fixed minor errors in pc-rao doc * made revisions to pc-rao doc * mod to pc-rao doc * update to docs * minor revision to h-h doc-string * added lkwta utility * Add retinal ganglion cell input encoder (#137) * Add RetinalGanglionCell component with filtering methods Implement RetinalGanglionCell with Gaussian filtering and patch extraction. * Add RetinalGanglionCell import to input_encoders * Add RetinalGanglionCell to input encoders * Enhance filter functions in ganglionCell.py Refactor Gaussian filter creation and add Difference of Gaussian filter functionality. * Refactor patch synapse (#138) * Refactor multi-patch synapse creation and initialization Refactor _create_multi_patch_synapses function to use n_modules instead of n_sub_models and update weight initialization. Introduce weight masks for synaptic weights. * Refactor HebbianPatchedSynapse and add attributes Refactor HebbianPatchedSynapse initialization and add new attributes for post-in and pre-out. * feat: Integrate MPSSynapse Component (#140) * feat: integrate MPSSynapse component for compressed synaptic transforms * style: conform to Google docstrings, move utils, and add unit tests * feat: implement native learning via evolve method and unit tests * Fixed MPS Matrix Properties: I fixed the .T transpose bug you interrupted earlier—because self.W10.weights inside an MPSSynapse generates the tensor via an einsum, returning an Array, get() throws an error. * Fix MPS synapse memory leak by implementing project_backward * Delete uv.lock * docs: add academic references and detailed docstrings to MPSSynapse * sorry, here you go, I loosened the test tolerances to 1e-2 as suggested --------- Co-authored-by: Alex Ororbia * integrated working som-synapse into competitive sub-package for synapses * cleaned up som-syn * update test code for hebbian patch synapse * fix SOM Synapse bug * Flexible batch size (#142) * Modify reset method to accept batch_size parameter for flexible test set size * Modify reset method to accept batch_size parameter * Refactor RateCell class reset function for flexible batch size * Refactor GaussianErrorCell class functions for flexible batch size * flexible batch_size * cleaned up graded/patched comps with inner batched_reset formulation * minor clean-up of som-syn * claned up ganglion-cell, added batched_reset * minor cleanup * added working hopfield-syn/modern-hopfield-syn * update SOM synapse to batchified version * integrated prototype for vector-quantize memory model/synapse * wrote/integrated an ART2A synapse model, batch-generalized * updates to art2a, cleanup of probes * updates to art2a, cleanup of probes * added in knn-probe for utils.analysis * cleaned up vq-synapse * cleaned up vq-synapse * tweaked/cleaned-up gaussian-error-cell * Update JaxProcessesMixin.py Added automatic jit wrapping which is on by default, add "use_jit=False" to a process to disable * minor patch fixes, including making .mask a compartment in key syn * patch to bernoulli/latency and wtas cells * update reset function of the ganglion cell Co-authored-by: Copilot * minor mod to model_utils * docs now with a few more mods * Nudge to release of v3.1.0 (#146) * create release branch * Dev (#62) * implemented raw classical instantaneous stdp * mod to classical stdp * mod to classical stdp syn * mod to stdp syn * mod to stdp syn * mod to stdp syn * mod to stdp syn * mod to stdp syn * minor mod of syn * cleaned up stdp syn * cleaned up stdp syn --------- Co-authored-by: ago109 * added block diag init to weight dist * added block diag init to weight dist * added block diag init to weight dist; with will optimization * slight extension to rate-cell for tensor-shaping * slight extension to rate-cell for tensor-shaping * slight extension to rate-cell for tensor-shaping * slightly modded bernoulli-cell help to reflect correct compartment names * nudge to readme for minor version shift to beta2 * nudge to readme for minor version shift to beta2 * mod to docs to prep for nudge to beta2 * nudge correctly to pip version beta3 * generalized rate-cell a bit * touched up rate-cell further * minor mod to lif * updated lif-cell to use units/tags and minor cleanup and edits * Monitor plot (#66) * Update base_monitor.py * added plotting viewed compartments * added meta-data to rate-cell, input encoders, adex * fixed minor saving/loading in rate-cell w/ vectorized compartments * Added auto resolving for monitors (#67) * fixed surr arg in lif-cell * modded bernoulli-cell to include max-frequency constraint * added warning check to bernoulli, some cleanup * integrated if-cell, cleaned up lif and inits * mod to latency-cell * updated the poissonCell to be a true poisson * fixed minor bug in deprecation for poiss/bern * fixed minor bug in deprecation for poiss/bern * fixed validation fun in bern/poiss * moved back and cleaned up bernoulli and poisson cells * added threshold-clipping to latency cell * updates to if/lif * added batch-size arg to slif * fixed minor load bug in lif-cell * fixed a blocking jit-partial call in lif update_theta method; when loading * minor edit to dim-reduce * Patched synapses added (#68) * Patched synapses added * Update __init__.py * Update patch_utils.py patch_with_stride & patch_with_overlap functions + Create_Patches class added * Update patchedSynapse.py * Update hebbianPatchedSynapse.py * Update synapse_plot.py order added * updated monitor plot code * update to dim-reduce * integrated phasor-cell, minor cleanup of latency * tweak to adex thr arg * tweak to adex thr arg * integrated resonate-and-fire neuronal cell * mod to raf-cell * cleaned up raf * cleaned up raf * cleaned up raf-cell * cleaned up raf-cell * cleaned up raf-cell * minor tweak to dim-reduce in utils * Fix typo in pcn_discrim.md (#69) * model_utils and rate cell (#70) * Patched synapses added * Update __init__.py * Update patch_utils.py patch_with_stride & patch_with_overlap functions + Create_Patches class added * Update patchedSynapse.py * Update hebbianPatchedSynapse.py * Update synapse_plot.py order added * Create hierarchical_sc.md 1 * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update sparse_coding.md * Update sparse_coding.md * Update sparse_coding.md * Update sparse_coding.md * Update hierarchical_sc.md * Update sparse_coding.md * Update sparse_coding.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Add files via upload * Delete docs/images/hgpc_network.pdf * Add files via upload * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Create hgpc * Delete docs/images/museum/hgpc * Create d * Add files via upload * Delete docs/images/hgpc_model.png * Delete docs/images/museum/hgpc/d * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Add files via upload * Update hierarchical_sc.md * Update hierarchical_sc.md * Delete docs/images/museum/hgpc/Input_layer.png * Add files via upload * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Create Generative_PC.md * Update and rename Generative_PC.md to generative_pc.md * Update generative_pc.md * Update generative_pc.md * Update model_utils.py * Update model_utils.py * Update model_utils.py * Update model_utils.py * Update rateCell.py * Update generative_pc.md * Create pc-sindy.md * Update pc-sindy.md * Update model_utils.py sine activation function added * Update model_utils.py * Update ode_utils.py jitified * Delete docs/museum/hierarchical_sc.md * Delete docs/museum/generative_pc.md * Delete ngclearn/components/synapses/patched directory * Update __init__.py * Add files via upload ode with scanner added * Update ode_solver.py _ removed * Fix/reorganize feature library (#74) * Update ode_utils.py * Update ode_solver.py rk4 revised and __main__ added * Delete ngclearn/utils/diffeq/ode_functions.py * Create odes.py odes name and structure changed * Update __init__.py * Create feature_library.py * Create __init__.py * Create base.py * Delete docs/museum/pc-sindy.md * Create m.md * Add files via upload * Delete docs/images/museum/sindy/m.md * Add files via upload * Create sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * fix: correct feature library path and directory name * Delete ngclearn/utils/dymbolic_dictionary directory * Update model_utils.py (#78) * Additions for inhibition stuff * add sindy documentation for exhibits (#81) * Add files via upload * Add files via upload * Update ode_utils.py (#79) refactor: delete @partial(jit, static_argnums=(2, )) lines Co-authored-by: Will Gebhardt * Add patched synapse (#80) * Update __init__.py Add point to patched components * Add patched in __init__.py Add patched synapses importing * Add patched synaptic components * Delete ngclearn/components/synapses/patched/__pycache__ directory * Update __init__.py new line characters added * Update hebbianPatchedSynapse.py * Update patchedSynapse.py new line characters added * Update staticPatchedSynapse.py new line characters added * Update staticPatchedSynapse.py New line characters + comments for describing each input vars * Update patchedSynapse.py Removed a comment line * Update hebbianPatchedSynapse.py remove unused arguments * Update hebbianPatchedSynapse.py * Update hebbianPatchedSynapse.py add description for w_mask * Update hebbianPatchedSynapse.py * Update hebbianPatchedSynapse.py * Update patchedSynapse.py * Update patchedSynapse.py * Update hebbianPatchedSynapse.py * Update __init__.py (#83) * Update __init__.py typo fixed * Update staticPatchedSynapse.py a typo fixed * Update hebbianPatchedSynapse.py typo foxed * Add l1 decay term to update calculation (#84) * Update hebbianSynapse.py * update main update main at the end * Update hebbianSynapse.py add regularization argument and w_decay is deprecated. * Update hebbianSynapse.py add elastic_net * Update hebbianSynapse.py * Update hebbianSynapse.py * feat NGC module regression (#86) * feat npc module regression * Update __init__.py * Update __init__.py * Update elastic_net.py * Update lasso.py * Update ridge.py * Update elastic_net.py * Update ridge.py * Update lasso.py * Update odes.py removed @partial(jit, static_argnums=(0,)) * Update odes.py (#87) removed @partial(jit, static_argnums=(0,)) * Update odes.py typo fixed in __main__ * Update __init__.py add dot * Update __init__.py add dot * Add attribute 'lr' (#90) * Update elastic_net.py * add lr as attribute to lasso.py * add lr as attribute to ridge.py * refactor w_bound=0. for weights elastic_net.py deactivated w_bound for weights elastic_net.py * Update lasso.py * deactivated w_bound for weights ridge.py * commit probes/mods to utils to analysis_tools branch * commit probes/mods to utils to analysis_tools branch * update documentation * cleaned up probes/docs for probes * change heads_dim to attn_dim, and modify the mlp to be as similar as possible to the attentive probing pattern * in layer normalization or any other Gaussian, standardeviation can never be zero. Additionally, if the subtraction inside the square root goes to zero, the gradient will become NaN. Therefore, adding a clipping is necessary. * update attentive probe code * minor tweak to attentive prob code comments * cleaned up probe parent fit routine * cleaned up probe parent fit routine * cleaned up probe parent fit routine * cleaned up probe parent fit routine * minor edits to attn probe * update attentive probe with input layer norm * update input layer normalization * update code to fix nan bug * minor tweak to attn probe * cleaned up probes * cleaned up probes * cleaned up probes * cleaned up probes * generalized dropout in terms of shape * tweak to atten probe * tweak to atten probe * added silu/swish/elu to model_utils * cleaned up model_utils * fix bug in attention probe dropout, fix bug in None noise_key passed in the probing jit function, add the spliting of noise_keys to two dropout in two cross attention * hyperparameter tunning arguments added * Merging over Dynamics feature branch to main (#92) * modded bernoulli-cell to include max-frequency constraint * added warning check to bernoulli, some cleanup * integrated if-cell, cleaned up lif and inits * mod to latency-cell * updated the poissonCell to be a true poisson * fixed minor bug in deprecation for poiss/bern * fixed minor bug in deprecation for poiss/bern * fixed validation fun in bern/poiss * moved back and cleaned up bernoulli and poisson cells * added threshold-clipping to latency cell * updates to if/lif * added batch-size arg to slif * fixed minor load bug in lif-cell * fixed a blocking jit-partial call in lif update_theta method; when loading * minor edit to dim-reduce * updated monitor plot code * update to dim-reduce * integrated phasor-cell, minor cleanup of latency * tweak to adex thr arg * tweak to adex thr arg * integrated resonate-and-fire neuronal cell * mod to raf-cell * cleaned up raf * cleaned up raf * cleaned up raf-cell * cleaned up raf-cell * cleaned up raf-cell * minor tweak to dim-reduce in utils * Additions for inhibition stuff * update to API modeling docs to reflect RAF neuronal cell --------- Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt * remove unused local variables * update note * update model utils * remove notes * Update ode utils (#94) * Update ode_utils.py merge ode_solver into ide_utils * Delete ngclearn/utils/diffeq/ode_solver.py * Update ode_utils.py refactor doc-string * minor fix to header in diffeq * Update files with ode_solver (#95) * Update ode_utils.py merge ode_solver into ide_utils * Delete ngclearn/utils/diffeq/ode_solver.py * Update ode_utils.py refactor doc-string * Update odes.py * Update sindy.md ode_solver to ode_utils * revised/cleaned up sindy tutorial doc/imgs * add prior for hebbian patched synapse (#96) * prior replaced w_decay hebbianPatchedSynapse.py remove w_decay add prior_type and prior_lmbda * revised typo hebbianSynapse.py dWweight was typo * cleaned up doc-strings in odes.py to comply w/ ngc-learn format * minor tweak to sig-figs printing in probe utils * add-sigma-to-gaussianErrorCell (#97) * add-sigma-to-gaussianErrorCell add not updating scalar variance for gaussian errors * Update gaussianErrorCell.py * cleaned up ode_utils, cleaned up gaussian/laplacian cell * Update gaussianErrorCell.py (#98) added `and not isinstance(sigma, int)` * cleaned up gauss/laplace error cells * integrated bernoulli err-cell * Major release update merge to main (in prep for 2.0.0 release on release branch/pip) (#99) * add initial patch mask features * minor edit to bern-cell * fixed bernoulli error cell * example rate cell test * made some corrections to bern err-cell and heb syn * made some corrections to bern err-cell and heb syn * cleaned up bern-cell, hebb-syn * minor mod to model-utils * attempted rewrite of bernoulli-cell * got bernoulli-cell rewritten and unit-tested * edit to bern-cell * bernoulli and poisson cells revised, unit-tested * latency-cell refactored and unit-tested * refactored Rate Cell * minor revisions to input-encoders, revised phasor-cell w/ unit-test * revised and add unit-test for varTrace * revised and added unit-test for exp-kernel * revised and added unit-test for exp-kernel * revised slif cell w/ unit-test; needed mod to diffeq * revised slif-cell w/ unit-test; cleaned up ode_utils to play nicer w/ new sim-lib * revised lif-cell w/ unit-test * revised unit-tests to pass globally; some minor patches to phasor-cell and lif * minor cleanup of unit-test for phasor * revised if-cell w/ unit-test * revised if-cell w/ unit-test * revised quad-lif w/ unit-test * revised adex-cell w/ unit test, minor cleanup of quad-lif * minor edit to adex unit-test * refactor bernoulli, laplacian, and rewarderror cells * revised raf-cell w/ unit test; fixed typos/mistakes in all spiking cells * revised wtas-cell w/ unit test * revised fh-cell w/ unit test * revised izh-cell w/ unit test * patched ode_utils backend wrt jax, cleaned up unit-tests, added disable flag for phasor-cell * update rate cell * fix test rate cell * update test for bernoulli cell * update refactoring for gaussian error cell * update unit testing for all graded neurons * wrote+unit-test of hodgkin-huxley spike cell, minor tweaks/clean-up elsewhere * added rk2 support for H-H cell * update rate cell and fix bug of passing a tuple of (jax Array -- not hashable) to jax jit functions. Basically, simplify the codebase by using a hashmap of functions * update test rate cell * refactored dense and trace-stdp syn w/ unit-test * refactored exp-stdp syn w/ unit-test * refactored event-stdp w/ unit-test * cleanup of stdp-syn * refactored bcm syn w/ unit-test * refactored stp-syn with unit-test * cleaned up modulated * refactored mstdp-et syn w/ unit-test * refactored lava components to new sim-lib * refactored conv/hebb-conv syn w/ unit-test * refactored/revised hebb-deconv syn w/ unit-test * revised/refactored hebb/stdp conv/deconv syn w/ unit-tests * updated modeling doc to point to hodgkin-huxley cell * updated modeling docs * fixed typo in adex-cell tutorial doc * revised tutorials to reflect new sim-lib config/syntax * revised tutorials to reflect new sim-lib config/syntax * patched docs to reflect revisions/refactor * tweaked requirements in prep for major release * cleaned up a few unit tests to use deterministic syn init vals * mod to requirements * nudge toml to upcoming 2.0.0 * update to support docs in prep for 2.0.0 * update patched synapses and their test cases * cleaned up syn modeling doc * push hebbian synapse * push reinforce synapse * push np seed * patched minor prior None arg issue in hebb-syn * moved reinforce-syn to right spot * update reinforce synapse and testing * tweaked trace-stdp and mstdpet * patched mstdpet unit-test * update reinforce synapse and test cases * add reinforce synapse fix * minor mod to mstdpet * update test code for more than 1 steps * Updated monitors * patched tests to use process naming * Added wrapper for reset and advance_state * Added a JaxProcess Added Jax Process to allow for scanning over the process. * update the old rate cell * update old hebbian synapse * minor edit to if-cell * ported over adex tutorial to new ngclearn format * hh-cell supports rk4 integration * clean up and integrated hodgkin-huxley mini lesson in neurocog tutorials * Update jaxProcess.py Updated the jax process to allow for more configurations of inputs. * update working reinforce synapse * update correct reinforce and testing * update documentation * update features, documentation, and testing * update testing for REINFORCE cell * update code and test * update code * add clipping gradient to model utils * update reinforce cell to the new model utils clip * update test cases --------- Co-authored-by: Viet Dung Nguyen Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt * Major release update (to 2.0.0) (#100) * add initial patch mask features * minor edit to bern-cell * fixed bernoulli error cell * example rate cell test * made some corrections to bern err-cell and heb syn * made some corrections to bern err-cell and heb syn * cleaned up bern-cell, hebb-syn * minor mod to model-utils * attempted rewrite of bernoulli-cell * got bernoulli-cell rewritten and unit-tested * edit to bern-cell * bernoulli and poisson cells revised, unit-tested * latency-cell refactored and unit-tested * refactored Rate Cell * minor revisions to input-encoders, revised phasor-cell w/ unit-test * revised and add unit-test for varTrace * revised and added unit-test for exp-kernel * revised and added unit-test for exp-kernel * revised slif cell w/ unit-test; needed mod to diffeq * revised slif-cell w/ unit-test; cleaned up ode_utils to play nicer w/ new sim-lib * revised lif-cell w/ unit-test * revised unit-tests to pass globally; some minor patches to phasor-cell and lif * minor cleanup of unit-test for phasor * revised if-cell w/ unit-test * revised if-cell w/ unit-test * revised quad-lif w/ unit-test * revised adex-cell w/ unit test, minor cleanup of quad-lif * minor edit to adex unit-test * refactor bernoulli, laplacian, and rewarderror cells * revised raf-cell w/ unit test; fixed typos/mistakes in all spiking cells * revised wtas-cell w/ unit test * revised fh-cell w/ unit test * revised izh-cell w/ unit test * patched ode_utils backend wrt jax, cleaned up unit-tests, added disable flag for phasor-cell * update rate cell * fix test rate cell * update test for bernoulli cell * update refactoring for gaussian error cell * update unit testing for all graded neurons * wrote+unit-test of hodgkin-huxley spike cell, minor tweaks/clean-up elsewhere * added rk2 support for H-H cell * update rate cell and fix bug of passing a tuple of (jax Array -- not hashable) to jax jit functions. Basically, simplify the codebase by using a hashmap of functions * update test rate cell * refactored dense and trace-stdp syn w/ unit-test * refactored exp-stdp syn w/ unit-test * refactored event-stdp w/ unit-test * cleanup of stdp-syn * refactored bcm syn w/ unit-test * refactored stp-syn with unit-test * cleaned up modulated * refactored mstdp-et syn w/ unit-test * refactored lava components to new sim-lib * refactored conv/hebb-conv syn w/ unit-test * refactored/revised hebb-deconv syn w/ unit-test * revised/refactored hebb/stdp conv/deconv syn w/ unit-tests * updated modeling doc to point to hodgkin-huxley cell * updated modeling docs * fixed typo in adex-cell tutorial doc * revised tutorials to reflect new sim-lib config/syntax * revised tutorials to reflect new sim-lib config/syntax * patched docs to reflect revisions/refactor * tweaked requirements in prep for major release * cleaned up a few unit tests to use deterministic syn init vals * mod to requirements * nudge toml to upcoming 2.0.0 * update to support docs in prep for 2.0.0 * update patched synapses and their test cases * cleaned up syn modeling doc * push hebbian synapse * push reinforce synapse * push np seed * patched minor prior None arg issue in hebb-syn * moved reinforce-syn to right spot * update reinforce synapse and testing * tweaked trace-stdp and mstdpet * patched mstdpet unit-test * update reinforce synapse and test cases * add reinforce synapse fix * minor mod to mstdpet * update test code for more than 1 steps * Updated monitors * patched tests to use process naming * Added wrapper for reset and advance_state * Added a JaxProcess Added Jax Process to allow for scanning over the process. * update the old rate cell * update old hebbian synapse * minor edit to if-cell * ported over adex tutorial to new ngclearn format * hh-cell supports rk4 integration * clean up and integrated hodgkin-huxley mini lesson in neurocog tutorials * Update jaxProcess.py Updated the jax process to allow for more configurations of inputs. * update working reinforce synapse * update correct reinforce and testing * update documentation * update features, documentation, and testing * update testing for REINFORCE cell * update code and test * update code * add clipping gradient to model utils * update reinforce cell to the new model utils clip * major cleanup in prep for merge over to main/prep for major release * update test cases * update to require file in docs --------- Co-authored-by: Viet Dung Nguyen Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt * Major release update merge to main (sync up) (#101) * add initial patch mask features * minor edit to bern-cell * fixed bernoulli error cell * example rate cell test * made some corrections to bern err-cell and heb syn * made some corrections to bern err-cell and heb syn * cleaned up bern-cell, hebb-syn * minor mod to model-utils * attempted rewrite of bernoulli-cell * got bernoulli-cell rewritten and unit-tested * edit to bern-cell * bernoulli and poisson cells revised, unit-tested * latency-cell refactored and unit-tested * refactored Rate Cell * minor revisions to input-encoders, revised phasor-cell w/ unit-test * revised and add unit-test for varTrace * revised and added unit-test for exp-kernel * revised and added unit-test for exp-kernel * revised slif cell w/ unit-test; needed mod to diffeq * revised slif-cell w/ unit-test; cleaned up ode_utils to play nicer w/ new sim-lib * revised lif-cell w/ unit-test * revised unit-tests to pass globally; some minor patches to phasor-cell and lif * minor cleanup of unit-test for phasor * revised if-cell w/ unit-test * revised if-cell w/ unit-test * revised quad-lif w/ unit-test * revised adex-cell w/ unit test, minor cleanup of quad-lif * minor edit to adex unit-test * refactor bernoulli, laplacian, and rewarderror cells * revised raf-cell w/ unit test; fixed typos/mistakes in all spiking cells * revised wtas-cell w/ unit test * revised fh-cell w/ unit test * revised izh-cell w/ unit test * patched ode_utils backend wrt jax, cleaned up unit-tests, added disable flag for phasor-cell * update rate cell * fix test rate cell * update test for bernoulli cell * update refactoring for gaussian error cell * update unit testing for all graded neurons * wrote+unit-test of hodgkin-huxley spike cell, minor tweaks/clean-up elsewhere * added rk2 support for H-H cell * update rate cell and fix bug of passing a tuple of (jax Array -- not hashable) to jax jit functions. Basically, simplify the codebase by using a hashmap of functions * update test rate cell * refactored dense and trace-stdp syn w/ unit-test * refactored exp-stdp syn w/ unit-test * refactored event-stdp w/ unit-test * cleanup of stdp-syn * refactored bcm syn w/ unit-test * refactored stp-syn with unit-test * cleaned up modulated * refactored mstdp-et syn w/ unit-test * refactored lava components to new sim-lib * refactored conv/hebb-conv syn w/ unit-test * refactored/revised hebb-deconv syn w/ unit-test * revised/refactored hebb/stdp conv/deconv syn w/ unit-tests * updated modeling doc to point to hodgkin-huxley cell * updated modeling docs * fixed typo in adex-cell tutorial doc * revised tutorials to reflect new sim-lib config/syntax * revised tutorials to reflect new sim-lib config/syntax * patched docs to reflect revisions/refactor * tweaked requirements in prep for major release * cleaned up a few unit tests to use deterministic syn init vals * mod to requirements * nudge toml to upcoming 2.0.0 * update to support docs in prep for 2.0.0 * update patched synapses and their test cases * cleaned up syn modeling doc * push hebbian synapse * push reinforce synapse * push np seed * patched minor prior None arg issue in hebb-syn * moved reinforce-syn to right spot * update reinforce synapse and testing * tweaked trace-stdp and mstdpet * patched mstdpet unit-test * update reinforce synapse and test cases * add reinforce synapse fix * minor mod to mstdpet * update test code for more than 1 steps * Updated monitors * patched tests to use process naming * Added wrapper for reset and advance_state * Added a JaxProcess Added Jax Process to allow for scanning over the process. * update the old rate cell * update old hebbian synapse * minor edit to if-cell * ported over adex tutorial to new ngclearn format * hh-cell supports rk4 integration * clean up and integrated hodgkin-huxley mini lesson in neurocog tutorials * Update jaxProcess.py Updated the jax process to allow for more configurations of inputs. * update working reinforce synapse * update correct reinforce and testing * update documentation * update features, documentation, and testing * update testing for REINFORCE cell * update code and test * update code * add clipping gradient to model utils * update reinforce cell to the new model utils clip * major cleanup in prep for merge over to main/prep for major release * update test cases * update to require file in docs --------- Co-authored-by: Viet Dung Nguyen Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt * update test cases * Nudging main v2.0.0 to release stage (formal release) (#102) * generalized rate-cell a bit * touched up rate-cell further * minor mod to lif * updated lif-cell to use units/tags and minor cleanup and edits * Monitor plot (#66) * Update base_monitor.py * added plotting viewed compartments * added meta-data to rate-cell, input encoders, adex * fixed minor saving/loading in rate-cell w/ vectorized compartments * Added auto resolving for monitors (#67) * fixed surr arg in lif-cell * modded bernoulli-cell to include max-frequency constraint * added warning check to bernoulli, some cleanup * integrated if-cell, cleaned up lif and inits * mod to latency-cell * updated the poissonCell to be a true poisson * fixed minor bug in deprecation for poiss/bern * fixed minor bug in deprecation for poiss/bern * fixed validation fun in bern/poiss * moved back and cleaned up bernoulli and poisson cells * added threshold-clipping to latency cell * updates to if/lif * added batch-size arg to slif * fixed minor load bug in lif-cell * fixed a blocking jit-partial call in lif update_theta method; when loading * minor edit to dim-reduce * Patched synapses added (#68) * Patched synapses added * Update __init__.py * Update patch_utils.py patch_with_stride & patch_with_overlap functions + Create_Patches class added * Update patchedSynapse.py * Update hebbianPatchedSynapse.py * Update synapse_plot.py order added * updated monitor plot code * update to dim-reduce * integrated phasor-cell, minor cleanup of latency * tweak to adex thr arg * tweak to adex thr arg * integrated resonate-and-fire neuronal cell * mod to raf-cell * cleaned up raf * cleaned up raf * cleaned up raf-cell * cleaned up raf-cell * cleaned up raf-cell * minor tweak to dim-reduce in utils * Fix typo in pcn_discrim.md (#69) * model_utils and rate cell (#70) * Patched synapses added * Update __init__.py * Update patch_utils.py patch_with_stride & patch_with_overlap functions + Create_Patches class added * Update patchedSynapse.py * Update hebbianPatchedSynapse.py * Update synapse_plot.py order added * Create hierarchical_sc.md 1 * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update sparse_coding.md * Update sparse_coding.md * Update sparse_coding.md * Update sparse_coding.md * Update hierarchical_sc.md * Update sparse_coding.md * Update sparse_coding.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Add files via upload * Delete docs/images/hgpc_network.pdf * Add files via upload * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Create hgpc * Delete docs/images/museum/hgpc * Create d * Add files via upload * Delete docs/images/hgpc_model.png * Delete docs/images/museum/hgpc/d * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Add files via upload * Update hierarchical_sc.md * Update hierarchical_sc.md * Delete docs/images/museum/hgpc/Input_layer.png * Add files via upload * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Update hierarchical_sc.md * Create Generative_PC.md * Update and rename Generative_PC.md to generative_pc.md * Update generative_pc.md * Update generative_pc.md * Update model_utils.py * Update model_utils.py * Update model_utils.py * Update model_utils.py * Update rateCell.py * Update generative_pc.md * Create pc-sindy.md * Update pc-sindy.md * Update model_utils.py sine activation function added * Update model_utils.py * Update ode_utils.py jitified * Delete docs/museum/hierarchical_sc.md * Delete docs/museum/generative_pc.md * Delete ngclearn/components/synapses/patched directory * Update __init__.py * Add files via upload ode with scanner added * Update ode_solver.py _ removed * Fix/reorganize feature library (#74) * Update ode_utils.py * Update ode_solver.py rk4 revised and __main__ added * Delete ngclearn/utils/diffeq/ode_functions.py * Create odes.py odes name and structure changed * Update __init__.py * Create feature_library.py * Create __init__.py * Create base.py * Delete docs/museum/pc-sindy.md * Create m.md * Add files via upload * Delete docs/images/museum/sindy/m.md * Add files via upload * Create sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * Update sindy.md * fix: correct feature library path and directory name * Delete ngclearn/utils/dymbolic_dictionary directory * Update model_utils.py (#78) * Additions for inhibition stuff * add sindy documentation for exhibits (#81) * Add files via upload * Add files via upload * Update ode_utils.py (#79) refactor: delete @partial(jit, static_argnums=(2, )) lines Co-authored-by: Will Gebhardt * Add patched synapse (#80) * Update __init__.py Add point to patched components * Add patched in __init__.py Add patched synapses importing * Add patched synaptic components * Delete ngclearn/components/synapses/patched/__pycache__ directory * Update __init__.py new line characters added * Update hebbianPatchedSynapse.py * Update patchedSynapse.py new line characters added * Update staticPatchedSynapse.py new line characters added * Update staticPatchedSynapse.py New line characters + comments for describing each input vars * Update patchedSynapse.py Removed a comment line * Update hebbianPatchedSynapse.py remove unused arguments * Update hebbianPatchedSynapse.py * Update hebbianPatchedSynapse.py add description for w_mask * Update hebbianPatchedSynapse.py * Update hebbianPatchedSynapse.py * Update patchedSynapse.py * Update patchedSynapse.py * Update hebbianPatchedSynapse.py * Update __init__.py (#83) * Update __init__.py typo fixed * Update staticPatchedSynapse.py a typo fixed * Update hebbianPatchedSynapse.py typo foxed * Add l1 decay term to update calculation (#84) * Update hebbianSynapse.py * update main update main at the end * Update hebbianSynapse.py add regularization argument and w_decay is deprecated. * Update hebbianSynapse.py add elastic_net * Update hebbianSynapse.py * Update hebbianSynapse.py * feat NGC module regression (#86) * feat npc module regression * Update __init__.py * Update __init__.py * Update elastic_net.py * Update lasso.py * Update ridge.py * Update elastic_net.py * Update ridge.py * Update lasso.py * Update odes.py removed @partial(jit, static_argnums=(0,)) * Update odes.py (#87) removed @partial(jit, static_argnums=(0,)) * Update odes.py typo fixed in __main__ * Update __init__.py add dot * Update __init__.py add dot * Add attribute 'lr' (#90) * Update elastic_net.py * add lr as attribute to lasso.py * add lr as attribute to ridge.py * refactor w_bound=0. for weights elastic_net.py deactivated w_bound for weights elastic_net.py * Update lasso.py * deactivated w_bound for weights ridge.py * commit probes/mods to utils to analysis_tools branch * commit probes/mods to utils to analysis_tools branch * update documentation * cleaned up probes/docs for probes * change heads_dim to attn_dim, and modify the mlp to be as similar as possible to the attentive probing pattern * in layer normalization or any other Gaussian, standardeviation can never be zero. Additionally, if the subtraction inside the square root goes to zero, the gradient will become NaN. Therefore, adding a clipping is necessary. * update attentive probe code * minor tweak to attentive prob code comments * cleaned up probe parent fit routine * cleaned up probe parent fit routine * cleaned up probe parent fit routine * cleaned up probe parent fit routine * minor edits to attn probe * update attentive probe with input layer norm * update input layer normalization * update code to fix nan bug * minor tweak to attn probe * cleaned up probes * cleaned up probes * cleaned up probes * cleaned up probes * generalized dropout in terms of shape * tweak to atten probe * tweak to atten probe * added silu/swish/elu to model_utils * cleaned up model_utils * fix bug in attention probe dropout, fix bug in None noise_key passed in the probing jit function, add the spliting of noise_keys to two dropout in two cross attention * hyperparameter tunning arguments added * Merging over Dynamics feature branch to main (#92) * modded bernoulli-cell to include max-frequency constraint * added warning check to bernoulli, some cleanup * integrated if-cell, cleaned up lif and inits * mod to latency-cell * updated the poissonCell to be a true poisson * fixed minor bug in deprecation for poiss/bern * fixed minor bug in deprecation for poiss/bern * fixed validation fun in bern/poiss * moved back and cleaned up bernoulli and poisson cells * added threshold-clipping to latency cell * updates to if/lif * added batch-size arg to slif * fixed minor load bug in lif-cell * fixed a blocking jit-partial call in lif update_theta method; when loading * minor edit to dim-reduce * updated monitor plot code * update to dim-reduce * integrated phasor-cell, minor cleanup of latency * tweak to adex thr arg * tweak to adex thr arg * integrated resonate-and-fire neuronal cell * mod to raf-cell * cleaned up raf * cleaned up raf * cleaned up raf-cell * cleaned up raf-cell * cleaned up raf-cell * minor tweak to dim-reduce in utils * Additions for inhibition stuff * update to API modeling docs to reflect RAF neuronal cell --------- Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt * remove unused local variables * update note * update model utils * remove notes * Update ode utils (#94) * Update ode_utils.py merge ode_solver into ide_utils * Delete ngclearn/utils/diffeq/ode_solver.py * Update ode_utils.py refactor doc-string * minor fix to header in diffeq * Update files with ode_solver (#95) * Update ode_utils.py merge ode_solver into ide_utils * Delete ngclearn/utils/diffeq/ode_solver.py * Update ode_utils.py refactor doc-string * Update odes.py * Update sindy.md ode_solver to ode_utils * revised/cleaned up sindy tutorial doc/imgs * add prior for hebbian patched synapse (#96) * prior replaced w_decay hebbianPatchedSynapse.py remove w_decay add prior_type and prior_lmbda * revised typo hebbianSynapse.py dWweight was typo * cleaned up doc-strings in odes.py to comply w/ ngc-learn format * minor tweak to sig-figs printing in probe utils * add-sigma-to-gaussianErrorCell (#97) * add-sigma-to-gaussianErrorCell add not updating scalar variance for gaussian errors * Update gaussianErrorCell.py * cleaned up ode_utils, cleaned up gaussian/laplacian cell * Update gaussianErrorCell.py (#98) added `and not isinstance(sigma, int)` * cleaned up gauss/laplace error cells * integrated bernoulli err-cell * Major release update merge to main (in prep for 2.0.0 release on release branch/pip) (#99) * add initial patch mask features * minor edit to bern-cell * fixed bernoulli error cell * example rate cell test * made some corrections to bern err-cell and heb syn * made some corrections to bern err-cell and heb syn * cleaned up bern-cell, hebb-syn * minor mod to model-utils * attempted rewrite of bernoulli-cell * got bernoulli-cell rewritten and unit-tested * edit to bern-cell * bernoulli and poisson cells revised, unit-tested * latency-cell refactored and unit-tested * refactored Rate Cell * minor revisions to input-encoders, revised phasor-cell w/ unit-test * revised and add unit-test for varTrace * revised and added unit-test for exp-kernel * revised and added unit-test for exp-kernel * revised slif cell w/ unit-test; needed mod to diffeq * revised slif-cell w/ unit-test; cleaned up ode_utils to play nicer w/ new sim-lib * revised lif-cell w/ unit-test * revised unit-tests to pass globally; some minor patches to phasor-cell and lif * minor cleanup of unit-test for phasor * revised if-cell w/ unit-test * revised if-cell w/ unit-test * revised quad-lif w/ unit-test * revised adex-cell w/ unit test, minor cleanup of quad-lif * minor edit to adex unit-test * refactor bernoulli, laplacian, and rewarderror cells * revised raf-cell w/ unit test; fixed typos/mistakes in all spiking cells * revised wtas-cell w/ unit test * revised fh-cell w/ unit test * revised izh-cell w/ unit test * patched ode_utils backend wrt jax, cleaned up unit-tests, added disable flag for phasor-cell * update rate cell * fix test rate cell * update test for bernoulli cell * update refactoring for gaussian error cell * update unit testing for all graded neurons * wrote+unit-test of hodgkin-huxley spike cell, minor tweaks/clean-up elsewhere * added rk2 support for H-H cell * update rate cell and fix bug of passing a tuple of (jax Array -- not hashable) to jax jit functions. Basically, simplify the codebase by using a hashmap of functions * update test rate cell * refactored dense and trace-stdp syn w/ unit-test * refactored exp-stdp syn w/ unit-test * refactored event-stdp w/ unit-test * cleanup of stdp-syn * refactored bcm syn w/ unit-test * refactored stp-syn with unit-test * cleaned up modulated * refactored mstdp-et syn w/ unit-test * refactored lava components to new sim-lib * refactored conv/hebb-conv syn w/ unit-test * refactored/revised hebb-deconv syn w/ unit-test * revised/refactored hebb/stdp conv/deconv syn w/ unit-tests * updated modeling doc to point to hodgkin-huxley cell * updated modeling docs * fixed typo in adex-cell tutorial doc * revised tutorials to reflect new sim-lib config/syntax * revised tutorials to reflect new sim-lib config/syntax * patched docs to reflect revisions/refactor * tweaked requirements in prep for major release * cleaned up a few unit tests to use deterministic syn init vals * mod to requirements * nudge toml to upcoming 2.0.0 * update to support docs in prep for 2.0.0 * update patched synapses and their test cases * cleaned up syn modeling doc * push hebbian synapse * push reinforce synapse * push np seed * patched minor prior None arg issue in hebb-syn * moved reinforce-syn to right spot * update reinforce synapse and testing * tweaked trace-stdp and mstdpet * patched mstdpet unit-test * update reinforce synapse and test cases * add reinforce synapse fix * minor mod to mstdpet * update test code for more than 1 steps * Updated monitors * patched tests to use process naming * Added wrapper for reset and advance_state * Added a JaxProcess Added Jax Process to allow for scanning over the process. * update the old rate cell * update old hebbian synapse * minor edit to if-cell * ported over adex tutorial to new ngclearn format * hh-cell supports rk4 integration * clean up and integrated hodgkin-huxley mini lesson in neurocog tutorials * Update jaxProcess.py Updated the jax process to allow for more configurations of inputs. * update working reinforce synapse * update correct reinforce and testing * update documentation * update features, documentation, and testing * update testing for REINFORCE cell * update code and test * update code * add clipping gradient to model utils * update reinforce cell to the new model utils clip * update test cases --------- Co-authored-by: Viet Dung Nguyen Co-authored-by: Alexander Ororbia Co-authored-by: Will Gebhardt * Major release update (to 2.0.0) (#100) * add initial patch mask features * minor edit to bern-cell * fixed bernoulli error cell * example rate cell test * made some corrections to bern err-cell and heb syn * made some corrections to bern err-cell and heb syn * cleaned up bern-cell, hebb-syn * minor mod to model-utils * attempted rewrite of bernoulli-cell * got bernoulli-cell rewritten and unit-tested * edit to bern-cell * bernoulli and poisson cells revised, unit-tested * latency-cell refactored and unit-tested * refactored Rate Cell * minor revisions to input-encoders, revised phasor-cell w/ unit-test * revised and add unit-test for varTrace * revised and added unit-test for exp-kernel * revised and added unit-test for exp-kernel * revised slif cell w/ unit-test; needed mod to diffeq * revised slif-cell w/ unit-test; cleaned up ode_utils to play nicer w/ new sim-lib * revised lif-cell w/ unit-test * revised unit-tests to pass globally; some minor patches to phasor-cell and lif * minor cleanup of unit-test for phasor * revised if-cell w/ unit-test * revised if-cell w/ unit-test * revised quad-lif w/ unit-test * revised adex-cell w/ unit test, minor cleanup of quad-lif * minor edit to adex unit-test * refactor bernoulli, laplacian, and rewarderror cells * revised raf-cell w/ unit test; fixed typos/mistakes in all spiking cells * revised wtas-cell w/ unit test * revised fh-cell w/ unit test * revised izh-cell w/ unit test * patched ode_utils backend wrt jax, cleaned up unit-tests, added disable flag for phasor-cell * update rate cell * fix test rate cell * update test for bernoulli cell * update refactoring for gaussian error cell * update unit testing for all graded neurons * wrote+unit-test of hodgkin-huxley spike cell, minor tweaks/clean-up elsewhere * added rk2 support for H-H cell * update rate cell and fix bug of passing a tuple of (jax Array -- not hashable) to jax jit functions. Basically, simplify the codebase by using a hashmap of functions * update test rate cell * refactored dense and trace-stdp syn w/ unit-test * refactored exp-stdp syn w/ unit-test * refactored event-stdp w/ unit-test * cleanup of stdp-syn * refactored bcm syn w/ unit-test * refactored stp-syn with unit-test * cleaned up modulated * refactored mstdp-et syn w/ unit-test * refactored lava components to new sim-lib * refactored conv/hebb-conv syn w/ unit-test * refactored/revised hebb-deconv syn w/ unit-test * revised/refactored hebb/stdp conv/deconv syn w/ unit-tests * updated modeling doc to point to hodgkin-huxley cell * updated modeling docs * fixed typo in adex-cell tutorial doc * revised tutorials to reflect new sim-lib config/syntax * revised tutorials to reflect new sim-lib config/syntax * patched docs to reflect revisions/refactor * tweaked… * minor edit to history * fixed to header/requirements * fixed to header/requirements * fixed to header/requirements * cleanup of ganglion cell * cleanup of metric util --------- Co-authored-by: Alexander Ororbia Co-authored-by: Alexander Ororbia Co-authored-by: Viet Dung Nguyen <60036798+rxng8@users.noreply.github.com> Co-authored-by: Will Gebhardt Co-authored-by: Viet Nguyen Co-authored-by: Viet Dung Nguyen Co-authored-by: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com> Co-authored-by: Anton Vice <118047001+antonvice@users.noreply.github.com> Co-authored-by: Copilot Co-authored-by: Sonny George <56851635+sonnygeorge@users.noreply.github.com> Co-authored-by: Alexander Ororbia Co-authored-by: Ankur Mali --- history.txt | 2 + ngclearn/__init__.py | 2 +- .../components/input_encoders/ganglionCell.py | 69 ++++++++++--------- ngclearn/utils/metric_utils.py | 4 ++ requirements.txt | 3 +- 5 files changed, 47 insertions(+), 33 deletions(-) diff --git a/history.txt b/history.txt index 4345a232..a8c1e724 100644 --- a/history.txt +++ b/history.txt @@ -99,4 +99,6 @@ History * new suite of competitive learning synapses (generalized formats), including vector-quantization, self-organizing map, adaptive resonance theory (contus-version), and modern hopfield network * revisions to metric and model utils * some additional clean-up, including supported retinal ganglion encoder + * fixed pkg-resource header (in both ngcsimlib and ngclearn) to account for newer python(s) + diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py index 6121286c..5670cfca 100644 --- a/ngclearn/__init__.py +++ b/ngclearn/__init__.py @@ -31,7 +31,7 @@ "currently installed!") ################################################################################## -## Needed to preload is called before anything in ngclearn +## Following are needed to preload is called before anything in ngclearn from pathlib import Path from sys import argv import numpy diff --git a/ngclearn/components/input_encoders/ganglionCell.py b/ngclearn/components/input_encoders/ganglionCell.py index a35c7b0a..f3b2a41e 100644 --- a/ngclearn/components/input_encoders/ganglionCell.py +++ b/ngclearn/components/input_encoders/ganglionCell.py @@ -70,7 +70,10 @@ def _create_patches(obs, patch_shape, step_shape): class RetinalGanglionCell(JaxComponent): """ - A group of retinal ganglion cell that senses the input stimuli and sends out the filtered signal to the brain. + A group of retinal ganglion cell that sense input stimuli and send out filtered + signals (as output). Note that these simulated cells employ internal generalized + filters based on either Gaussian or difference-of-Gaussian kernels) to recover + historical receptive field processing effects. | --- Cell Input Compartments: --- | inputs - input (takes in external signals) @@ -85,32 +88,34 @@ class RetinalGanglionCell(JaxComponent): filter_type: string name of filter function (Default: identity) :Note: supported filters include "gaussian", "difference_of_gaussian" - sigma: standard deviation of gaussian kernel + sigma: standard deviation of (gaussian) kernel - area_shape: receptive field area of ganglion cells in this module all together + area_shape: shape of receptive field area of ganglion cells in this module (all together) n_cells: number of ganglion cells in this module - patch_shape: each ganglion cell receptive field area + patch_shape: shape of each ganglion cell's receptive field area - step_shape: the non-overlapping area between each two ganglion cells + step_shape: the non-overlapping area between each pair (two) of ganglion cells - batch_size: batch size dimension of this cell (Default: 1) + batch_size: batch size dimension of this cell/module (Default: 1) """ - def __init__(self, name: str, - filter_type: str, - area_shape: Tuple[int, int], - n_cells: int, - patch_shape: Tuple[int, int], - step_shape: Tuple[int, int], - batch_size: int = 1, - sigma: float = 1.0, - key: Union[jax.Array, None] = None, - **kwargs): + def __init__( + self, + name: str, + filter_type: str, + area_shape: Tuple[int, int], + n_cells: int, + patch_shape: Tuple[int, int], + step_shape: Tuple[int, int], + batch_size: int = 1, + sigma: float = 1.0, + key: Union[jax.Array, None] = None, + **kwargs + ): super().__init__(name=name, key=key) - ## Layer Size Setup self.filter_type = filter_type self.n_cells = n_cells @@ -143,14 +148,14 @@ def __init__(self, name: str, @compilable def advance_state(self, t): inputs = self.inputs.get() - filter = self.filter.get() + _filter = self.filter.get() px, py = self.patch_shape # ═══════════════════ extract pathches for filters ══════════════════ input_patches = _create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape) # ═══════════════════ apply filter to all pathches ══════════════════ - filtered_input = input_patches * filter ## shape: (B | n_cells | px | py) + filtered_input = input_patches * _filter ## shape: (B | n_cells | px | py) # ════════════ reshape all cells responses to a single input to brain ════════════ filtered_input = filtered_input.reshape(-1, self.n_cells * (px * py)) ## shape: (B | n_cells * px * py) @@ -184,7 +189,7 @@ def reset(self): ## reset core components/statistics @classmethod def help(cls): ## component help function properties = { - "cell_type": "RetinalGanglionCell - filters the input stimuli, " + "cell_type": "RetinalGanglionCell - filters the input stimuli according retinal ganglion dynamics" } compartment_props = { "inputs": @@ -196,11 +201,11 @@ def help(cls): ## component help function } hyperparams = { "filter_type": "Type of the filter for preprocessing the input", - "sigma": "Standard deviation of gaussian kernel", + "sigma": "Standard deviation of gaussian kernel/filter", "area_shape": "Effective receptive field area shape of ganglion cells in this module", - "n_cells": "Number of Retinal Ganglion (center-surround) cells to model in this layer", - "patch_shape": "Classical Receptive field area shape of individual ganglion cells in this module", - "step_shape": "Extra-Classical Receptive field area shape each ganglion cell in this module", + "n_cells": "Number of retinal ganglion (center-surround) cells to model in this layer", + "patch_shape": "Classical receptive field area shape of individual ganglion cells in this module", + "step_shape": "Extra-classical receptive field area shape each ganglion cell in this module", "batch_size": "Batch size dimension of this component" } info = {cls.__name__: properties, @@ -212,13 +217,15 @@ def help(cls): ## component help function if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: - X = RetinalGanglionCell("RGC", filter_type="gaussian", - sigma=2.3, - area_shape=(16, 26), - n_cells = 3, - patch_shape=(16, 16), - step_shape=(0, 5) - ) + X = RetinalGanglionCell( + "RGC", + filter_type="gaussian", + sigma=2.3, + area_shape=(16, 26), + n_cells = 3, + patch_shape=(16, 16), + step_shape=(0, 5) + ) print(X) diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index 457043cb..ae3c6230 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -89,9 +89,13 @@ def measure_sparsity(codes, tolerance=0., preserve_batch=True, flip_measure=Fals this matrix is a non-negative vector. Formally, this means we compute, per i-th row: + | rho(x_i) = num_zeros(x_i) / dim(x_i) + and for a global score for matrix X with N codes/rows, we measure: + | rho_mean(X) = 1/N Sum^N_{i=1} rho(x_i) + where lower/closer to 0 means codes more sparse and closer to 1 means codes are more dense. diff --git a/requirements.txt b/requirements.txt index 09ab6214..978bd1fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ matplotlib>=3.9.4 # patchify # patchify has issues with pip installation jax>=0.4.28 jaxlib>=0.4.28 -ngcsimlib>=3.0.0 +ngcsimlib>=3.1.0 imageio>=2.37.0 pandas>=2.2.3 +typing_extensions>=4.15.0 From 6feae7f7eabf7afdc40a86273c008cbe3a61c461 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 1 May 2026 13:12:26 -0400 Subject: [PATCH 02/23] minor cleanup of filter variable name in ganglion-encoder --- .../components/input_encoders/ganglionCell.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/ngclearn/components/input_encoders/ganglionCell.py b/ngclearn/components/input_encoders/ganglionCell.py index f3b2a41e..b7f3bf7a 100644 --- a/ngclearn/components/input_encoders/ganglionCell.py +++ b/ngclearn/components/input_encoders/ganglionCell.py @@ -126,23 +126,22 @@ def __init__( self.patch_shape = patch_shape self.step_shape = step_shape - filter = jnp.ones(self.patch_shape) - + _filter = jnp.ones(self.patch_shape) if filter_type == 'gaussian': - filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma) + _filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma) elif filter_type == 'difference_of_gaussian': - filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma) + _filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma) # ═════════════════ compartments initial values ════════════════════ - in_restVals = jnp.zeros((batch_size, - *self.area_shape)) ## input: (B | ix | iy) + in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy) - out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py) - self.n_cells * self.patch_shape[0] * self.patch_shape[1])) + out_restVals = jnp.zeros( + (batch_size, self.n_cells * self.patch_shape[0] * self.patch_shape[1]) + ) ## output.shape: (B | n_cells * px * py) # ═══════════════════ set compartments ══════════════════════ self.inputs = Compartment(in_restVals, display_name="Input Stimulus") # input compartment - self.filter = Compartment(filter, display_name="Filter") # Filter compartment + self.filter = Compartment(_filter, display_name="Filter") # Filter compartment self.outputs = Compartment(out_restVals, display_name="Output Signal") # output compartment @compilable From 244ceab664e193c8d8f35e6747d139e985af3c12 Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen Date: Fri, 1 May 2026 20:27:07 -0400 Subject: [PATCH 03/23] update batch size setup for elestic net, lasso, and ridge Co-authored-by: Copilot --- ngclearn/modules/regression/elastic_net.py | 6 ++++-- ngclearn/modules/regression/lasso.py | 6 ++++-- ngclearn/modules/regression/ridge.py | 6 ++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ngclearn/modules/regression/elastic_net.py b/ngclearn/modules/regression/elastic_net.py index 5860d2bc..06cd09bd 100644 --- a/ngclearn/modules/regression/elastic_net.py +++ b/ngclearn/modules/regression/elastic_net.py @@ -77,13 +77,15 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l self.W = HebbianSynapse( "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1, weight_init=dist.constant(value=weight_fill), prior=('elastic_net', (lmbda, l1_ratio)), w_bound=0., - optim_type=optim_type, key=subkeys[0] + optim_type=optim_type, key=subkeys[0], batch_size=batch_size ) - self.err = GaussianErrorCell("err", n_units=sys_dim) + self.err = GaussianErrorCell("err", n_units=sys_dim, batch_size=batch_size) # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.batch_size = batch_size self.err.batch_size = batch_size + self.W.reset() + self.err.reset() # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.outputs >> self.err.mu self.err.dmu >> self.W.post diff --git a/ngclearn/modules/regression/lasso.py b/ngclearn/modules/regression/lasso.py index 15a014bb..8d5f91d5 100644 --- a/ngclearn/modules/regression/lasso.py +++ b/ngclearn/modules/regression/lasso.py @@ -76,12 +76,14 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l self.W = HebbianSynapse( "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1, weight_init=dist.constant(value=weight_fill), prior=('lasso', lasso_lmbda), w_bound=0., - optim_type=optim_type, key=subkeys[0] + optim_type=optim_type, key=subkeys[0], batch_size=batch_size ) - self.err = GaussianErrorCell("err", n_units=sys_dim) + self.err = GaussianErrorCell("err", n_units=sys_dim, batch_size=batch_size) # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.batch_size = batch_size self.err.batch_size = batch_size + self.W.reset() + self.err.reset() # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.outputs >> self.err.mu self.err.dmu >> self.W.post diff --git a/ngclearn/modules/regression/ridge.py b/ngclearn/modules/regression/ridge.py index dfbacb03..84591670 100644 --- a/ngclearn/modules/regression/ridge.py +++ b/ngclearn/modules/regression/ridge.py @@ -76,13 +76,15 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l self.W = HebbianSynapse( "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1, weight_init=dist.constant(value=weight_fill), prior=('ridge', ridge_lmbda), w_bound=0., - optim_type=optim_type, key=subkeys[0] + optim_type=optim_type, key=subkeys[0], batch_size=batch_size ) - self.err = GaussianErrorCell("err", n_units=sys_dim) + self.err = GaussianErrorCell("err", n_units=sys_dim, batch_size=batch_size) # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.batch_size = batch_size self.err.batch_size = batch_size + self.W.reset() + self.err.reset() # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.outputs >> self.err.mu self.err.dmu >> self.W.post From a105cdaa6527ce8dd839c8c3c4416decffc5addd Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sat, 2 May 2026 16:43:30 -0400 Subject: [PATCH 04/23] integrated gini-index into metric-utils --- ngclearn/utils/metric_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index ae3c6230..16c32c15 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -82,6 +82,22 @@ def measure_breadth_TC(spikes, preserve_batch=False): BTC = jnp.mean(BTC) return BTC +@partial(jit, static_argnums=[1]) +def measure_gini_index(codes, preserve_batch=True): + ## Gini index + ### values closer to 1 indicate high sparsity (sparser codes) + ### values closer to 0 indicate lower sparsity (denser codes) + ## calculation requires sorted array for O(n) or O(nlogn) + D = codes.shape[1] ## length of vector + codes_sorted = jnp.sort(jnp.abs(codes), axis=1) ## sort all codes w/in batch matrix + index = jnp.arange(1, D + 1) + term1 = jnp.sum((2 * index - D - 1) * codes_sorted, axis=1, keepdims=True) + term2 = D * jnp.sum(codes_sorted, axis=1, keepdims=True) + gini = term1 / term2 + if not preserve_batch: + gini = jnp.mean(gini) + return gini + @partial(jit, static_argnums=[2, 3]) def measure_sparsity(codes, tolerance=0., preserve_batch=True, flip_measure=False): """ From 788d7a82945996e44c2a3b4a1e942b726ba643ca Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sat, 2 May 2026 17:18:43 -0400 Subject: [PATCH 05/23] integrated gini-index into metric-utils --- ngclearn/utils/metric_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index 16c32c15..cc35a197 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -87,15 +87,18 @@ def measure_gini_index(codes, preserve_batch=True): ## Gini index ### values closer to 1 indicate high sparsity (sparser codes) ### values closer to 0 indicate lower sparsity (denser codes) - ## calculation requires sorted array for O(n) or O(nlogn) + + ### note that the calculation below is faster than the mean-absolute-value + ### form of gini-index; below calculation requires sorting but yields a + ### lower-complexity calculation D = codes.shape[1] ## length of vector codes_sorted = jnp.sort(jnp.abs(codes), axis=1) ## sort all codes w/in batch matrix index = jnp.arange(1, D + 1) term1 = jnp.sum((2 * index - D - 1) * codes_sorted, axis=1, keepdims=True) term2 = D * jnp.sum(codes_sorted, axis=1, keepdims=True) - gini = term1 / term2 + gini = term1 / term2 ## calc final ratio if not preserve_batch: - gini = jnp.mean(gini) + gini = jnp.mean(gini) ## this is the mean gini-index return gini @partial(jit, static_argnums=[2, 3]) From abe28f8a7ca16b3399890bd7894fe0e635717198 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sat, 2 May 2026 17:34:16 -0400 Subject: [PATCH 06/23] integrated gini-index into metric-utils --- ngclearn/utils/metric_utils.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index cc35a197..bab15882 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -67,12 +67,12 @@ def measure_breadth_TC(spikes, preserve_batch=False): spikes: full spike train matrix; shape is (T x D) where D is number of neurons in a group/cluster - preserve_batch: if True, will return one score per sample in batch + preserve_batch: if True, will return one score per neuron in train/window (Default: False), otherwise, returns scalar average score Returns: - a 1 x D Fano factor vector (one factor per neuron) OR a single - average Fano factor across the neuronal group + a 1 x D BTC vector (one factor per neuron) OR a single + average BTC across the neuronal group """ mu = jnp.mean(spikes, axis=0, keepdims=True) sigSqr = jnp.square(jnp.std(spikes, axis=0, keepdims=True)) @@ -84,6 +84,23 @@ def measure_breadth_TC(spikes, preserve_batch=False): @partial(jit, static_argnums=[1]) def measure_gini_index(codes, preserve_batch=True): + """ + Calculates the gini index a group of neurons represented as vector code samples. + Gini index measures the sparseness of the values within each vector code, where + a higher index value indicates higher sparsity and a lower index value indicates a + lower sparsity (higher density). + + Args: + codes: a batch of neural codes; shape is (N x D) where D is number of + neurons in a group/cluster and N is number of samples + + preserve_batch: if True, will return one score per sample in batch + (Default: False), otherwise, returns scalar average score + + Returns: + a N x 1 Gini index vector (one score per neuron) OR a single + average Gini score for the whole sample/set of codes + """ ## Gini index ### values closer to 1 indicate high sparsity (sparser codes) ### values closer to 0 indicate lower sparsity (denser codes) From 1f3883ab6c53908749756ea7f721fb712eff24ea Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 5 May 2026 19:10:49 -0400 Subject: [PATCH 07/23] minor cleanup in utils --- ngclearn/components/synapses/denseSynapse.py | 2 +- ngclearn/utils/model_utils.py | 17 ++++++ ngclearn/utils/viz/raster.py | 1 + ngclearn/utils/viz/synapse_plot.py | 57 +++++++++++++++----- 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index ee1ecb02..92c9c5e9 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -83,7 +83,7 @@ def __init__( self.inputs = Compartment(preVals) self.outputs = Compartment(postVals) self.weights = Compartment(weights) - _mask = 1. + _mask = jnp.ones((1, 1)) if mask is not None: _mask = mask self.mask = Compartment(_mask) diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index bf20079b..9f22cc76 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -770,6 +770,23 @@ def create_block_matrix(map_matrix, group_shape, alpha_inh=-1., alpha_exc=1.): gmat = jnp.concatenate(gmat, axis=0) return gmat +@partial(jit, static_argnums=[0]) +def eye_wrapped(N, k, values): + """ + Creates an N x N matrix with a wrapped off-diagonal. + + Args: + N: Size of the square matrix (N x N) + + k: Diagonal offset (positive=above, negative=below) + + values: Array of values to place (length should match n) + """ + matrix = jnp.zeros((N, N)) ## Create empty matrix + row_indices = jnp.arange(N) ## Generate indices for the diagonal + col_indices = (row_indices + k) % N ## Wrap column indices using modulo + return matrix.at[row_indices, col_indices].set(values) ## Fill diagonal using efficient indexing + @partial(jit, static_argnums=[1, 2, 3, 4]) def normalize_block_matrix(matrix, block_size, order=2, axis=0, norm_targ=1.): """ diff --git a/ngclearn/utils/viz/raster.py b/ngclearn/utils/viz/raster.py index dff8745b..875a52e8 100755 --- a/ngclearn/utils/viz/raster.py +++ b/ngclearn/utils/viz/raster.py @@ -140,3 +140,4 @@ def create_overlay_raster_plot(spike_train, targ_train, Y, idxs, s=1.5, c="black plt.savefig(plot_fname + '_' + str(idx) + suffix) plt.clf() plt.close() + diff --git a/ngclearn/utils/viz/synapse_plot.py b/ngclearn/utils/viz/synapse_plot.py index 0f268491..60c9ccf5 100644 --- a/ngclearn/utils/viz/synapse_plot.py +++ b/ngclearn/utils/viz/synapse_plot.py @@ -9,8 +9,13 @@ import jax.numpy as jnp - -def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'): +def visualize( + thetas, + sizes, + prefix, + order=None, + suffix='.jpg' +): """ Args: @@ -22,8 +27,6 @@ def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'): suffix: """ - - if order is None: order = ['C' for _ in range(len(thetas))] @@ -62,7 +65,14 @@ def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'): plt.close() -def visualize_labels(thetas, sizes, prefix, space_width=None, widths=None, suffix='.jpg'): +def visualize_labels( + thetas, + sizes, + prefix, + space_width=None, + widths=None, + suffix='.jpg' +): """ Args: @@ -139,14 +149,34 @@ def visualize_labels(thetas, sizes, prefix, space_width=None, widths=None, suffi fig.savefig(prefix+suffix, bbox_inches='tight') plt.close(fig) -def visualize_frame(frame, path='.', name='tmp', suffix='.jpg', **kwargs): +def visualize_frame( + frame, + path='.', + name='tmp', + suffix='.jpg', + **kwargs +): iio.imwrite(path + '/' + name + suffix, frame.astype(jnp.uint8), **kwargs) -def visualize_gif(frames, path='.', name='tmp', suffix='.jpg', **kwargs): +def visualize_gif( + frames, + path='.', + name='tmp', + suffix='.jpg', + **kwargs +): _frames = [f.astype(jnp.uint8) for f in frames] iio.imwrite(path + '/' + name + '.gif', _frames, **kwargs) -def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1, **kwargs): +def make_video( + f_start, + f_end, + path, + prefix, + suffix='.jpg', + skip=1, + **kwargs +): images = [] for i in range(f_start, f_end+1, skip): print("Reading frame " + str(i)) @@ -154,10 +184,13 @@ def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1, **kwargs): print("writing gif") iio.imwrite(path + '/training.gif', images, **kwargs) - -# def visualize_norm(thetas, sizes, prefix, suffix='.jpg'): - -def viz_block(thetas, sizes, prefix, suffix=".jpg", padding=1, low_rez=True): +def viz_block( + thetas, + sizes, prefix, + suffix=".jpg", + padding=1, + low_rez=True +): num_filters = [T.shape[1] for T in thetas] n_cols = [math.ceil(math.sqrt(nf)) for nf in num_filters] n_rows = [math.ceil(nf / c) for nf, c in zip(num_filters, n_cols)] From 2f5f885645ede5e04459a931baa61b370a325f45 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 15 May 2026 19:35:11 -0400 Subject: [PATCH 08/23] updates to utils/metrics/filters --- .../components/input_encoders/ganglionCell.py | 10 +-- .../components/synapses/hebbian/__init__.py | 1 + ngclearn/utils/filters/__init__.py | 1 + ngclearn/utils/filters/gauss_filter.py | 66 +++++++++++++++++++ ngclearn/utils/metric_utils.py | 4 +- ngclearn/utils/viz/synapse_plot.py | 7 +- 6 files changed, 77 insertions(+), 12 deletions(-) create mode 100644 ngclearn/utils/filters/__init__.py create mode 100644 ngclearn/utils/filters/gauss_filter.py diff --git a/ngclearn/components/input_encoders/ganglionCell.py b/ngclearn/components/input_encoders/ganglionCell.py index b7f3bf7a..562b6b5d 100644 --- a/ngclearn/components/input_encoders/ganglionCell.py +++ b/ngclearn/components/input_encoders/ganglionCell.py @@ -8,24 +8,18 @@ def _create_gaussian_filter(patch_shape, sigma): ## Create a 2D Gaussian kernel centered on patch_shape with given sigma. px, py = patch_shape - x_ = jnp.linspace(0, px - 1, px) y_ = jnp.linspace(0, py - 1, py) - x, y = jnp.meshgrid(x_, y_) - xc = px // 2 yc = py // 2 - - filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2))) - return filter / jnp.sum(filter) + _filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2))) + return _filter / jnp.sum(_filter) def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1): g1 = _create_gaussian_filter(patch_shape, sigma=sigma) g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k) - dog = g1 - lmbda * g2 - return dog #- jnp.mean(dog) def _create_patches(obs, patch_shape, step_shape): diff --git a/ngclearn/components/synapses/hebbian/__init__.py b/ngclearn/components/synapses/hebbian/__init__.py index 05bfd207..0a1630c3 100644 --- a/ngclearn/components/synapses/hebbian/__init__.py +++ b/ngclearn/components/synapses/hebbian/__init__.py @@ -4,4 +4,5 @@ from .expSTDPSynapse import ExpSTDPSynapse from .eventSTDPSynapse import EventSTDPSynapse from .BCMSynapse import BCMSynapse +from .gerstnerHebbianSynapse import GerstnerHebbianSynapse ## Taylor-expansion Hebbian model diff --git a/ngclearn/utils/filters/__init__.py b/ngclearn/utils/filters/__init__.py new file mode 100644 index 00000000..0dc711b9 --- /dev/null +++ b/ngclearn/utils/filters/__init__.py @@ -0,0 +1 @@ +from .gauss_filter import gaussian_filter diff --git a/ngclearn/utils/filters/gauss_filter.py b/ngclearn/utils/filters/gauss_filter.py new file mode 100644 index 00000000..f080883f --- /dev/null +++ b/ngclearn/utils/filters/gauss_filter.py @@ -0,0 +1,66 @@ +import jax.numpy as jnp +from jax import lax, jit +from functools import partial + + +def _calc_gaussian_kernel_2D( ## internal co-routine + sigma: float, + radius: int +) -> jnp.ndarray: + ## Generate a (normalized) 2D Gaussian kernel of shape: (1, 1, 2*radius+1, 2*radius+1) + x = jnp.arange(-radius, radius + 1) + xx, yy = jnp.meshgrid(x, x) + kernel = jnp.exp(-0.5 * (xx ** 2 + yy ** 2) / sigma ** 2) + kernel = kernel / jnp.sum(kernel) + # Reshape to (out_channels, in_channels, height, width) for lax.conv_general_dilated + return kernel[jnp.newaxis, jnp.newaxis, :, :] + +@partial(jit, static_argnums=[3, 4]) +def gaussian_filter( + images: jnp.ndarray, ## input image batch + sigma_center: float, ## sigma1 + sigma_surround: float, ## sigma2 + kernel_size : int, ## radius + use_ratio=False ## if True, this becomes a ratio-of-Gaussians +) -> jnp.ndarray: + """ + Applies a difference-of-Gaussians filter to a batch of 2D images (of CxHxW tensor shape). + + Args: + images: Input image tensor of shape (B, C, H, W) + + sigma_center: Standard deviation for narrow / center blur + + sigma_surround: Standard deviation for wide / surround blur + + kernel_size: Kernel radius (window size will be `2*radius + 1`) + + use_ratio: whether or not to use a ratio-of-Gaussians filter (Default: False) + + Returns: + An output tensor of shape (B, C, H, W) + """ + x = images + ## Construct two 2D Gaussian kernels + k1 = _calc_gaussian_kernel_2D(sigma_center, kernel_size) ## center kernel + k2 = _calc_gaussian_kernel_2D(sigma_surround, kernel_size) ## surround kernel + ## Define dimension ordering for lax.conv ('NCHW' standard layout) + dn = lax.ConvDimensionNumbers( + lhs_spec=(0, 1, 2, 3), ## (batch, channel, height, width) + rhs_spec=(0, 1, 2, 3), ## (out_channel, in_channel, height, width) + out_spec=(0, 1, 2, 3) ## (batch, channel, height, width) + ) + ## Perform spatial convolutions w/ edge padding to emulate 'SAME' behavior + blur_center = lax.conv_general_dilated( + x, k1, window_strides=(1, 1), padding=[(kernel_size, kernel_size), (kernel_size, kernel_size)], dimension_numbers=dn + ) + blur_surround = lax.conv_general_dilated( + x, k2, window_strides=(1, 1), padding=[(kernel_size, kernel_size), (kernel_size, kernel_size)], dimension_numbers=dn + ) + ## Perform final filter calculation + if use_ratio: + eps = 1e-5 + output = blur_center / (blur_surround + eps) ## Compute kernel difference + else: + output = blur_center - blur_surround ## Compute kernel ratio + return output ## shape: (B, C, H, W) diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index bab15882..15ad7a0d 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -104,12 +104,12 @@ def measure_gini_index(codes, preserve_batch=True): ## Gini index ### values closer to 1 indicate high sparsity (sparser codes) ### values closer to 0 indicate lower sparsity (denser codes) - + _codes = codes + (jnp.sum(codes, axis=1, keepdims=True) <= 0.) + 1e-8 ### note that the calculation below is faster than the mean-absolute-value ### form of gini-index; below calculation requires sorting but yields a ### lower-complexity calculation D = codes.shape[1] ## length of vector - codes_sorted = jnp.sort(jnp.abs(codes), axis=1) ## sort all codes w/in batch matrix + codes_sorted = jnp.sort(jnp.abs(_codes), axis=1) ## sort all codes w/in batch matrix index = jnp.arange(1, D + 1) term1 = jnp.sum((2 * index - D - 1) * codes_sorted, axis=1, keepdims=True) term2 = D * jnp.sum(codes_sorted, axis=1, keepdims=True) diff --git a/ngclearn/utils/viz/synapse_plot.py b/ngclearn/utils/viz/synapse_plot.py index 60c9ccf5..f62a47da 100644 --- a/ngclearn/utils/viz/synapse_plot.py +++ b/ngclearn/utils/viz/synapse_plot.py @@ -55,8 +55,11 @@ def visualize( point = start + 1 + i + (r * extra) plt.subplot(n_rows_total, n_cols_total, point) - filter = T[i, :] - plt.imshow(np.reshape(filter, (sizes[idx][0], sizes[idx][1]), order=order[idx]), cmap=plt.cm.bone, interpolation='nearest') + _filter = T[i, :] + plt.imshow( + np.reshape(_filter, (sizes[idx][0], sizes[idx][1]), order=order[idx]), + cmap=plt.cm.bone, interpolation='nearest' + ) plt.axis("off") plt.subplots_adjust(top=0.9) From 1e3073a8c73ee52cb1d5d012616b7807a3113426 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 18 May 2026 22:37:14 -0400 Subject: [PATCH 09/23] added heatmap tool for utils.viz.classification_analysis --- ngclearn/utils/io_utils.py | 1 + ngclearn/utils/metric_utils.py | 4 +- ngclearn/utils/viz/__init__.py | 2 + ngclearn/utils/viz/classification_analysis.py | 56 +++++++++++++++++++ 4 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 ngclearn/utils/viz/classification_analysis.py diff --git a/ngclearn/utils/io_utils.py b/ngclearn/utils/io_utils.py index 8553af44..5365b34e 100755 --- a/ngclearn/utils/io_utils.py +++ b/ngclearn/utils/io_utils.py @@ -78,3 +78,4 @@ def load_pkl(directory: str, name: str) -> Any: with open(file_name, 'rb') as f: data = pickle.load(f) return data + diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index 15ad7a0d..37187bed 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -176,9 +176,9 @@ def analyze_scores(mu, y, extract_label_indx=True): ## examines classifcation st y: target / ground-truth (design) matrix; shape is (N x C) OR an array of class integers of length N (with "extract_label_indx = True") - extract_label_indx: run an argmax to pull class integer indices from + extract_label_indx: wehn True, run an argmax to pull class integer indices from "y", assuming y is a one-hot binary encoding matrix (Default: True), - otherwise, this assumes "y" is an array of class integer indices + otherwise, if False, this treats "y" is an array of class integer indices of length N Returns: diff --git a/ngclearn/utils/viz/__init__.py b/ngclearn/utils/viz/__init__.py index e37c46fc..605d9bef 100644 --- a/ngclearn/utils/viz/__init__.py +++ b/ngclearn/utils/viz/__init__.py @@ -2,3 +2,5 @@ from . import raster from . import spike_plot from . import synapse_plot +from . import classification_analysis + diff --git a/ngclearn/utils/viz/classification_analysis.py b/ngclearn/utils/viz/classification_analysis.py new file mode 100644 index 00000000..c49a3d0f --- /dev/null +++ b/ngclearn/utils/viz/classification_analysis.py @@ -0,0 +1,56 @@ +import matplotlib.pyplot as plt +import numpy as np + +def visualize_confusion_heatmap( + confuse_matrix, + classes, + out_fname, + figure_title="Confusion Matrix", + color_map="Blues", # "Greens" "Reds" + fontsize=10, + show_percent=True +): + _conf = confuse_matrix + pscale = 1. ## percentage scale + if show_percent: + _conf = (_conf / np.sum(_conf, axis=1, keepdims=True)) * 100 + ## Initialize plot + fig, ax = plt.subplots(figsize=(6, 6)) + vmin = np.floor(np.min(_conf)) + vmax = np.ceil(np.max(_conf)) + im = ax.imshow( + _conf, interpolation="nearest", cmap=color_map, vmin=vmin, vmax=vmax + ) + ## Add color bar + fig.colorbar(im, ax=ax, shrink=0.75) + ## Configure axes labels + ax.set_xticks(np.arange(len(classes))) + ax.set_yticks(np.arange(len(classes))) + ax.set_xticklabels(classes) + ax.set_yticklabels(classes) + + ax.set_xlabel("Predicted Label", fontsize=12, labelpad=10) + ax.set_ylabel("True Label", fontsize=12, labelpad=10) + ax.set_title(figure_title, fontsize=14, pad=15) + + ## Loop to print numbers on top of each colored cell + threshold = _conf.max() / 2.0 # Find midpoint to flip text color for readability + for i in range(_conf.shape[0]): + for j in range(_conf.shape[1]): + ## Use white text on dark cells, black text on light cells + color = "white" if _conf[i, j] > threshold else "black" + ax.text( + j, + i, + f"{_conf[i, j]:.1f}", #format(confuse_matrix[i, j], "d"), + ha="center", + va="center", + color=color, + fontsize=fontsize, + weight="bold", + ) + + ## Ensure layout fits tightly + plt.tight_layout() + plt.savefig(out_fname) + From 22ba7aaee944dd69b8e523296c8b54c9a318f8e1 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 19 May 2026 12:52:55 -0400 Subject: [PATCH 10/23] added class-conformity metrics in metric_utils; integrated kmeans-probe in utils.analysis --- ngclearn/utils/analysis/__init__.py | 2 +- ngclearn/utils/analysis/kmeans_probe.py | 109 +++++++++ ngclearn/utils/analysis/knn_probe.py | 44 ++-- ngclearn/utils/metric_utils.py | 227 ++++++++++++++++++ ngclearn/utils/viz/classification_analysis.py | 7 +- 5 files changed, 363 insertions(+), 26 deletions(-) create mode 100644 ngclearn/utils/analysis/kmeans_probe.py diff --git a/ngclearn/utils/analysis/__init__.py b/ngclearn/utils/analysis/__init__.py index 082c68cd..31506c3e 100644 --- a/ngclearn/utils/analysis/__init__.py +++ b/ngclearn/utils/analysis/__init__.py @@ -2,4 +2,4 @@ from .linear_probe import LinearProbe from .attentive_probe import AttentiveProbe from .knn_probe import KNNProbe - +from .kmeans_probe import KMeansProbe diff --git a/ngclearn/utils/analysis/kmeans_probe.py b/ngclearn/utils/analysis/kmeans_probe.py new file mode 100644 index 00000000..676a00cd --- /dev/null +++ b/ngclearn/utils/analysis/kmeans_probe.py @@ -0,0 +1,109 @@ +import jax +from ngcsimlib import deprecate_args +from ngclearn.utils.analysis.probe import Probe +from jax import jit, random, numpy as jnp, lax, nn +from functools import partial as bind +from ngclearn.utils.metric_utils import measure_ARI + +@bind(jax.jit, static_argnums=[2]) +def _run_kmeans_probe(_embeddings, centroids, n_clusters): + ## Broadcast distances: (n_samples, 1, n_features) - (1, n_clusters, n_features) + distances = jnp.sum((_embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1) + labels_pred = jnp.argmin(distances, axis=1) + ## Re-estimate centroids/means + one_hot_preds = labels_pred[:, None] == jnp.arange(n_clusters) + counts = jnp.maximum(one_hot_preds.sum(axis=0, keepdims=True).T, 1.0) + centroids = jnp.dot(one_hot_preds.T.astype(jnp.float32), _embeddings) / counts + return centroids + +@bind(jax.jit, static_argnums=[2]) +def _predict_with_probe(_embeddings, centroids, n_clusters): + ## Final pass to compute stable predictions + distances = jnp.sum((_embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1) + labels_pred = jnp.argmin(distances, axis=1) + Y_pred = nn.one_hot(labels_pred, n_clusters) + return labels_pred, Y_pred + +class KMeansProbe(Probe): + """ + This implements a K-means clustering probe, which is useful for evaluating the quality of + encodings/embeddings in light of the ability to cluster downstream data. Currently, this + probe only supports L2/Euclidean distance-based clustering. + + Args: + dkey: init seed key + + source_seq_length: length of input sequence (e.g., height x width of the image feature) + + input_dim: input dimensionality of probe + + out_dim: output dimensionality of probe - number of clusters for this probe to create + + batch_size: + + """ + + def __init__( + self, + dkey, + source_seq_length, + input_dim, + out_dim=2, ## number of clusters/centroids to uncover + batch_size=1, + **kwargs + ): + super().__init__(dkey, batch_size, **kwargs) + self.dkey, *subkeys = random.split(self.dkey, 3) + self.source_seq_length = source_seq_length + self.input_dim = input_dim + self.n_clusters = self.out_dim = out_dim + ## centroids that will be uncovered by this probe + self.centroids : jax.Array = None + + def _init(self, embeddings): + _embeddings = embeddings + if len(_embeddings.shape) > 2: + flat_dim = embeddings.shape[1] * embeddings.shape[2] + _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) + ## choose random data-points to serve as centroids at iteration 0 + self.dkey, *subkeys = random.split(self.dkey, 15) + n_samples, n_features = _embeddings.shape + random_indices = random.choice( + subkeys[0], n_samples, shape=(self.n_clusters,), replace=False + ) + self.centroids = _embeddings[random_indices] + + def process(self, embeddings, dkey=None): + _embeddings = embeddings + if len(_embeddings.shape) > 2: + flat_dim = embeddings.shape[1] * embeddings.shape[2] + _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) + ## Compute final geometric vs semantic conformity via ARI + _, Y_pred = _predict_with_probe(_embeddings, self.centroids, self.n_clusters) + return Y_pred ## (B, C) + + def update(self, embeddings, labels, dkey=None): + _embeddings = embeddings + if len(_embeddings.shape) > 2: + flat_dim = embeddings.shape[1] * embeddings.shape[2] + _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) + self.centroids = _run_kmeans_probe(_embeddings, self.centroids, self.n_clusters) + L = 0. ## FIXME: should be clustering loss + predictions = self.process(_embeddings) + return L, predictions + + def fit(self, dataset, dev_dataset=None, n_iter=20, patience=20): + data, labels = dataset + _labels = jnp.argmax(labels, axis=-1) + + self._init(data) ## init K-means centroids + ari = 0. + for i in range(n_iter): ## Run vectorized K-Means optimization loop + _L, py = self.update(data, labels) + labels_pred = jnp.argmax(py, axis=1) + ari_i = measure_ARI(_labels, labels_pred) + print(f"\r{i}: ARI = {ari_i}", end="") + if ari_i > ari: + ari = ari_i + print() + return ari diff --git a/ngclearn/utils/analysis/knn_probe.py b/ngclearn/utils/analysis/knn_probe.py index 6935a3a0..54f061c9 100644 --- a/ngclearn/utils/analysis/knn_probe.py +++ b/ngclearn/utils/analysis/knn_probe.py @@ -139,25 +139,25 @@ def update(self, embeddings, labels, dkey=None): Wy = labels self.probe_params = (Wx, Wy) -if __name__ == '__main__': - seed = 42 - D = 7 - C = 5 - dkey = random.PRNGKey(seed) - dkey, *subkeys = random.split(dkey, 3) - knn = KNNProbe( - subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean" - ) - X = random.uniform(subkeys[1], shape=(10, D)) - Y = jnp.concat( - [ - jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]), - jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]), - jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]), - jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]), - jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]]) - ], - axis=0 - ) - knn.update(X, Y) ## fit KNN to data - print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y +# if __name__ == '__main__': +# seed = 42 +# D = 7 +# C = 5 +# dkey = random.PRNGKey(seed) +# dkey, *subkeys = random.split(dkey, 3) +# knn = KNNProbe( +# subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean" +# ) +# X = random.uniform(subkeys[1], shape=(10, D)) +# Y = jnp.concat( +# [ +# jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]), +# jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]), +# jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]), +# jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]), +# jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]]) +# ], +# axis=0 +# ) +# knn.update(X, Y) ## fit KNN to data +# print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index 37187bed..38100611 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -463,3 +463,230 @@ def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10 if not preserve_batch: bce = jnp.mean(bce) return bce + + +@partial(jit, static_argnums=[2, 3]) +def _compute_contingency_table( ## vectorized construction of contingency matrix + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + n_classes: int, + n_clusters: int +) -> jnp.ndarray: + ## Computes a contingency matrix table + ## This routine expects true integer labels and predicted integer labels (1D arrays of size N) + + # Create indicator masks across all unique classes/clusters + # find unique IDs safely up to a static maximum size (or provide num_classes) + # n_classes = n_true = jnp.max(labels_true) + 1 + # n_clusters = n_pred = jnp.max(labels_pred) + 1 + + # Broadcast to form a full one-hot lookup map + true_mask = labels_true[:, None] == jnp.arange(n_classes) + pred_mask = labels_pred[:, None] == jnp.arange(n_clusters) + + # Contingency matrix is the matrix product of boolean indicators + contingency = jnp.dot(true_mask.T.astype(jnp.float32), pred_mask.astype(jnp.float32)) + return contingency + + +def measure_ARI( + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray +) -> jnp.ndarray: + """ + Computes the adjusted random index (ARI), which measures similarity between two + sets of indices (ground truth against a clustering's produced indices) via counting the + pairs of data points assigned to same or different clusters (adjusted for chance). This + measurement lies in `[0, 1]`, where `0` indicates a random labeling/assignment and `1` indicates + perfect agreement. + + Args: + labels_true: 1D array of shape (n_samples,) with true integer class labels. + + labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels. + + Returns: + scalar ARI of these two sets of indices + """ + ## Dynamically find dimensions up to a statically bounded maximum + n_classes = int(jnp.max(labels_true) + 1) + n_clusters = int(jnp.max(labels_pred) + 1) + return _calc_adjusted_rand_index(labels_true, labels_pred, n_classes, n_clusters) + + +@partial(jit, static_argnums=[2, 3]) +def _calc_adjusted_rand_index( ## ARI + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + n_classes: int, + n_clusters: int +) -> jnp.ndarray: + n_samples = labels_true.shape[0] + if n_samples <= 1: + return jnp.array(1.0) + + ## Get contingency matrix (n_classes x n_clusters) + contingency = _compute_contingency_table( + labels_true, + labels_pred, + n_classes, + n_clusters + ) + + ## Calculate combination sums n_ijC2 = (n_ij * (n_ij - 1)) / 2 + sum_nij_c2 = jnp.sum((contingency * (contingency - 1.0)) / 2.0) + + ## Sums across margins (rows and columns) + sum_a = jnp.sum(contingency, axis=1) + sum_b = jnp.sum(contingency, axis=0) + + ## Margin pair combinations + sum_a_c2 = jnp.sum((sum_a * (sum_a - 1.0)) / 2.0) + sum_b_c2 = jnp.sum((sum_b * (sum_b - 1.0)) / 2.0) + + ## Expected index and Max index math formulas + total_c2 = (n_samples * (n_samples - 1.0)) / 2.0 + expected_index = (sum_a_c2 * sum_b_c2) / total_c2 + max_index = (sum_a_c2 + sum_b_c2) / 2.0 + + ## Prevent division by zero if everything is perfectly clustered or uniform + denominator = max_index - expected_index + ari = jnp.where(denominator == 0.0, 1.0, (sum_nij_c2 - expected_index) / denominator) + return ari + + +def measure_FMI( + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray +) -> jnp.ndarray: + """ + Calculates the Fowlkes-Mallows Index (FMI), which measures similarity between two sets of + indices - this score is the geometric mean of pair-wise recall and precision. + This measurement lies in `[0, 1]`, where higher is better (indicating greater similarity between + two clustering sets of identifiers). + + Args: + labels_true: 1D array of shape (n_samples,) with true integer class labels. + + labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels. + + Returns: + scalar FMI of these two sets of indices + """ + ## Dynamically find dimensions up to a statically bounded maximum + n_classes = int(jnp.max(labels_true) + 1) + n_clusters = int(jnp.max(labels_pred) + 1) + return _measure_fowlkes_mallows_index(labels_true, labels_pred, n_classes, n_clusters) + + +@partial(jit, static_argnums=[2, 3]) +def _measure_fowlkes_mallows_index( ## FMI + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + n_classes: int, + n_clusters: int +) -> jnp.ndarray: + n_samples = labels_true.shape[0] + # Handle edge case for single or empty samples safely + if n_samples <= 1: + return jnp.array(0.0, dtype=jnp.float32) + + contingency = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters) + + ## Compute marginal sums (sums along rows and columns) + sum_true = jnp.sum(contingency, axis=1) + sum_pred = jnp.sum(contingency, axis=0) + + ## Calculate pairwise combinations using the matrix shortcut: nC2 = 0.5 * (sum(x^2) - N) + # True Positives pair combinations (tk) + tk = 0.5 * (jnp.sum(contingency ** 2) - n_samples) + ## Total pairs clustered together in ground truth (tr) + tr = 0.5 * (jnp.sum(sum_true ** 2) - n_samples) + ## Total pairs clustered together in predictions (tc) + tc = 0.5 * (jnp.sum(sum_pred ** 2) - n_samples) + + ## Compute FMI = tk / sqrt(tr * tc) + # Prevent division by zero if there are no pair splits/matches + denominator = jnp.sqrt(tr * tc) + fmi = jnp.where(denominator == 0.0, 0.0, tk / denominator) + return fmi + + +def measure_Vmeasure( ## V-Measure + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + beta: float = 1.0 +) -> jnp.ndarray: + """ + Calculates the V-Measure scoring metric for class conformity. This measurement compares + predicted cluster indices ("labels_pred") against ground truth indices ("labels_true") and + represents the harmonic mean of homogeneity (where each cluster contains only members of a single class) + as well as completeness (where all members of a given class are assigned to the same cluster). + This measurement (higher is better) lies in `[0,1]` where `1` indicates perfect, correct clustering. + + Args: + labels_true: 1D array of shape (n_samples,) with true integer class labels + + labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels + + beta: Weight factor. Ratios > 1.0 favor completeness, < 1.0 favor homogeneity. + + Returns: + scalar V-measure of these two sets of indices + """ + ## Dynamically find dimensions up to a statically bounded maximum + n_classes = int(jnp.max(labels_true) + 1) + n_clusters = int(jnp.max(labels_pred) + 1) + return _measure_v_measure_score(labels_true, labels_pred, n_classes, n_clusters, beta) + + +@partial(jit, static_argnums=[2, 3, 4]) +def _measure_v_measure_score( ## V-Measure + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + n_classes: int, + n_clusters: int, + beta: float = 1.0 +) -> jnp.ndarray: + n_samples = labels_true.shape[0] + + ## Handle edge case for single or empty samples safely + if n_samples <= 1: + return jnp.array(0.0, dtype=jnp.float32) + + contingency = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters) + + ## Calculate Marginal Sums (Row and Column totals) + sum_true = jnp.sum(contingency, axis=1) + sum_pred = jnp.sum(contingency, axis=0) + + ## Compute Base Entropies H(True) and H(Pred) + p_true = sum_true / n_samples + h_true = -jnp.sum(jnp.where(p_true > 0.0, p_true * jnp.log(p_true), 0.0)) + + p_pred = sum_pred / n_samples + h_pred = -jnp.sum(jnp.where(p_pred > 0.0, p_pred * jnp.log(p_pred), 0.0)) + + ## Compute Joint Entropy H(True, Pred) + p_joint = contingency / n_samples + h_joint = -jnp.sum(jnp.where(p_joint > 0.0, p_joint * jnp.log(p_joint), 0.0)) + + ## Derive Conditional Entropies: H(True|Pred) and H(Pred|True) using identity rule + h_true_given_pred = h_joint - h_pred + h_pred_given_true = h_joint - h_true + + ## Compute Homogeneity (H) and Completeness (C) + ## If base entropy is 0, the metric is perfectly satisfied (1.0) + homogeneity = jnp.where(h_true == 0.0, 1.0, 1.0 - (h_true_given_pred / h_true)) + completeness = jnp.where(h_pred == 0.0, 1.0, 1.0 - (h_pred_given_true / h_pred)) + + ## Compute Weighted Harmonic Mean (V-Measure) + denominator = beta * homogeneity + completeness + + ## Prevent division by zero if both metrics are zero + v_measure = jnp.where( + denominator == 0.0, + 0.0, + (1.0 + beta) * homogeneity * completeness / denominator + ) + return v_measure diff --git a/ngclearn/utils/viz/classification_analysis.py b/ngclearn/utils/viz/classification_analysis.py index c49a3d0f..cce819f5 100644 --- a/ngclearn/utils/viz/classification_analysis.py +++ b/ngclearn/utils/viz/classification_analysis.py @@ -8,12 +8,13 @@ def visualize_confusion_heatmap( figure_title="Confusion Matrix", color_map="Blues", # "Greens" "Reds" fontsize=10, - show_percent=True + norm_by="none" ): _conf = confuse_matrix - pscale = 1. ## percentage scale - if show_percent: + if "recall" in norm_by: ## normalize by row (recall) _conf = (_conf / np.sum(_conf, axis=1, keepdims=True)) * 100 + elif "precision" in norm_by: ## normalize by col (precision) + _conf = (_conf / np.sum(_conf, axis=0, keepdims=True)) * 100 ## Initialize plot fig, ax = plt.subplots(figsize=(6, 6)) vmin = np.floor(np.min(_conf)) From abe7dfa78b7edbee92a16ab69567f9492efdc7eb Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 20 May 2026 00:37:27 -0400 Subject: [PATCH 11/23] minor edits/updates --- .../hebbian/gerstnerHebbianSynapse.py | 113 ++++++++++++++++++ ngclearn/utils/filters/gauss_filter.py | 27 +++-- 2 files changed, 132 insertions(+), 8 deletions(-) create mode 100644 ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py diff --git a/ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py b/ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py new file mode 100644 index 00000000..d1efca5a --- /dev/null +++ b/ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py @@ -0,0 +1,113 @@ +import jax.numpy as jnp +from jax import random, jit + +from ngclearn import compilable +from ngclearn import Compartment +from ngclearn.components.synapses import DenseSynapse +from ngclearn.utils import tensorstats +from ngcsimlib import deprecate_args +#from ngclearn.utils.io_utils import save_pkl, load_pkl + +class GerstnerHebbianSynapse(DenseSynapse): + """ + A synapse component that implements Gerstner's general Hebbian + learning (Taylor) expansion (Equation 3 from Gerstner & Kistler, 2002). + + Note that this synpatic update model can recover several classical forms + of Hebbian-like update rules, including the covariance rule. + + There are other higher-order terms possible, i.e., \Theta(xy), such as + x * y2 and y x^2, etc. + + | c2_corr > 0 and c0 = c1_pre = c1_post = 0 => Hebbian update + | c2_corr < 0 and c0 = c1_pre = c1_post = 0 => anti-Hebbian update + | c2_corr = 1 and c1_pre = -x_theta < 0 + + """ + def __init__( + self, + name, + shape, ## (post_dim, pre_dim) + eta=0.01, ## global step-size + coeffs=None, ## these configure which kind of Hebb learning is done + weight_init=None, + p_conn=1., + resist_scale=1., + sign_value=1., + batch_size=1, + **kwargs + ): + bias_init = None ## no biases are included in Gerster's formulation + super().__init__( + name, + shape=shape, + weight_init=weight_init, + bias_init=bias_init, + resist_scale=resist_scale, + p_conn=p_conn, + batch_size=batch_size, + **kwargs + ) + ## General Hebbian meta-parameters + self.eta = eta + self.sign_value = sign_value + + ## Expansion coefficients (c0, c1_pre, c1_post, c2_corr) + if coeffs is None: ## Default to standard bilinear Hebb + self.coeffs = { + 'c0': 0., 'c1_pre': 0., 'c1_post': 0., 'c2_corr': 1.0 + } + else: + self.coeffs = coeffs + self.c0 = self.coeffs['c0'] + self.c1_pre = self.coeffs['c1_pre'] + self.c1_post = self.coeffs['c1_post'] + self.c2_corr = self.coeffs['c2_corr'] + + # Initialize Weights (using JAX PRNG) + #init_key, _ = random.split(self.key) + #w_init = random.normal(init_key, shape) * 0.05 + + # Compartments (ngc-learn state management) + #self.weights = Compartment(w_init) + self.pre = Compartment(jnp.zeros((1, shape[1]))) + self.post = Compartment(jnp.zeros((1, shape[0]))) + + @compilable + def evolve(self, **kwargs): + """ + Updates weights using the Gerstner general expansion. + Assumes pre_act and post_act compartments have been populated. + """ + # Retrieve current states + W = self.weights.get() + x = self.pre.get() # pre-synaptic activity (batch, pre_dim) + y = self.post.get() # post-synaptic activity (batch, post_dim) + batch_size = self.batch_size + + ## Bilinear Term (c2): correlation matrix + ### (post_dim, batch) @ (batch, pre_dim) -> (post_dim, pre_dim) + dW_corr = jnp.matmul(x.T, y) * (1./batch_size) + ## Linear pre-synaptic term (c1_pre) + ### Average over batch then broadcast to match weight matrix + dW_pre = jnp.sum(x, axis=0, keepdims=True).T * (1./batch_size) + ## Linear post-synaptic term (c1_post) + dW_post = jnp.sum(y, axis=0, keepdims=True) * (1./batch_size) + + ## Apply Equation 3 Taylor expansion + dW = (self.c0 * W + ## synaptic decay + self.c1_pre * dW_pre + ## bilinear term + self.c1_post * dW_post + ## pre-synaptic gating term + self.c2_corr * dW_corr ## post-synpatic gating term + ) + ## perform a step of Hebbian ascent + W = W + self.eta * dW + ## Update weights + self.weights.set(W) + + @compilable + def reset(self, **kwargs): + """Clears activity compartments""" + self.pre.set( jnp.zeros((self.batch_size, self.shape[1])) ) + self.post.set( jnp.zeros((self.batch_size, self.shape[0])) ) + diff --git a/ngclearn/utils/filters/gauss_filter.py b/ngclearn/utils/filters/gauss_filter.py index f080883f..dcca211b 100644 --- a/ngclearn/utils/filters/gauss_filter.py +++ b/ngclearn/utils/filters/gauss_filter.py @@ -2,7 +2,6 @@ from jax import lax, jit from functools import partial - def _calc_gaussian_kernel_2D( ## internal co-routine sigma: float, radius: int @@ -21,7 +20,8 @@ def gaussian_filter( sigma_center: float, ## sigma1 sigma_surround: float, ## sigma2 kernel_size : int, ## radius - use_ratio=False ## if True, this becomes a ratio-of-Gaussians + use_ratio=False, ## if True, this becomes a ratio-of-Gaussians + edge_pad_mode="edge" ## "reflect" ) -> jnp.ndarray: """ Applies a difference-of-Gaussians filter to a batch of 2D images (of CxHxW tensor shape). @@ -40,27 +40,38 @@ def gaussian_filter( Returns: An output tensor of shape (B, C, H, W) """ - x = images + ## Pad spatial dimensions (H, W) using edge-clamping to remove artifacts + # Format for 4D (B, C, H, W): ((Before_B, After_B), (Before_C, After_C), (Before_H, After_H), (Before_W, After_W)) + padding_config = ((0, 0), (0, 0), (kernel_size, kernel_size), (kernel_size, kernel_size)) + padded_x = jnp.pad(images, padding_config, mode=edge_pad_mode) + ## Construct two 2D Gaussian kernels k1 = _calc_gaussian_kernel_2D(sigma_center, kernel_size) ## center kernel k2 = _calc_gaussian_kernel_2D(sigma_surround, kernel_size) ## surround kernel + ## Define dimension ordering for lax.conv ('NCHW' standard layout) dn = lax.ConvDimensionNumbers( lhs_spec=(0, 1, 2, 3), ## (batch, channel, height, width) rhs_spec=(0, 1, 2, 3), ## (out_channel, in_channel, height, width) out_spec=(0, 1, 2, 3) ## (batch, channel, height, width) ) - ## Perform spatial convolutions w/ edge padding to emulate 'SAME' behavior + + ## Extract channel count dynamically for independent channel-wise filtering + num_channels = images.shape[1] + + ## Perform spatial convolutions w/ 'VALID' padding on the edge-padded input blur_center = lax.conv_general_dilated( - x, k1, window_strides=(1, 1), padding=[(kernel_size, kernel_size), (kernel_size, kernel_size)], dimension_numbers=dn + padded_x, k1, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels ) blur_surround = lax.conv_general_dilated( - x, k2, window_strides=(1, 1), padding=[(kernel_size, kernel_size), (kernel_size, kernel_size)], dimension_numbers=dn + padded_x, k2, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels ) + ## Perform final filter calculation if use_ratio: eps = 1e-5 - output = blur_center / (blur_surround + eps) ## Compute kernel difference + output = blur_center / (blur_surround + eps) ## Compute kernel ratio else: - output = blur_center - blur_surround ## Compute kernel ratio + output = blur_center - blur_surround ## Compute kernel difference return output ## shape: (B, C, H, W) + From d527ab6f064f057441d40e8b892dc4faa9053cd3 Mon Sep 17 00:00:00 2001 From: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com> Date: Wed, 20 May 2026 21:29:21 -0400 Subject: [PATCH 12/23] Refactor ganglion cell (#151) * Add patch reconstruction function to ganglionCell.py Added a new function to reconstruct patches from input data, improving data handling in the RetinalGanglionCell class. * Add ratio of Gaussian filter function --- .../components/input_encoders/ganglionCell.py | 58 ++++++++++++++----- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/ngclearn/components/input_encoders/ganglionCell.py b/ngclearn/components/input_encoders/ganglionCell.py index 562b6b5d..b7e58c05 100644 --- a/ngclearn/components/input_encoders/ganglionCell.py +++ b/ngclearn/components/input_encoders/ganglionCell.py @@ -22,6 +22,14 @@ def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1): dog = g1 - lmbda * g2 return dog #- jnp.mean(dog) + +def _create_ratio_of_gauss_filter(patch_shape, sigma, k=1.6): + g1 = _create_gaussian_filter(patch_shape, sigma=sigma) + g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k) + rog = g1 / (g2 + 1e-8) + return rog + + def _create_patches(obs, patch_shape, step_shape): """ Extract 2D patches from a batch of images using a sliding window. @@ -61,6 +69,29 @@ def _create_patches(obs, patch_shape, step_shape): return patches +def _reconstruct(patches, nx_ny, area_shape, patch_shape, step_shape): + # patches: (N, nx * ny, px, py) + + B = len(patches) + nx, ny = nx_ny + ix, iy = area_shape + px, py = patch_shape + sx, sy = step_shape + x = jnp.zeros((B, ix, iy)) + counts = jnp.zeros((ix, iy)) + + idx = 0 + for i in range(ny): + for j in range(nx): + di = i * sx + dj = j * sy + x = x.at[:, di:di + px, dj:dj + py].add(patches[:, idx]) + counts = counts.at[di:di + px, dj:dj + py].add(1.0) + idx += 1 + + return x / counts[None, :, :] + + class RetinalGanglionCell(JaxComponent): """ @@ -121,10 +152,18 @@ def __init__( self.step_shape = step_shape _filter = jnp.ones(self.patch_shape) - if filter_type == 'gaussian': + + if self.filter_type == 'gaussian': + print("filter type is ", self.filter_type) _filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma) - elif filter_type == 'difference_of_gaussian': + + elif self.filter_type in ["difference_of_gaussian", "DoG"]: + print("filter type is difference of gaussian: f(x) = p1 - p2") _filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma) + + elif self.filter_type in ["ratio_of_gaussian", "RoG"]: + print("filter type is ratio of gaussian: f(x) = p1 / p2") + _filter = _create_ratio_of_gauss_filter(patch_shape=self.patch_shape, sigma=sigma) # ═════════════════ compartments initial values ════════════════════ in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy) @@ -158,27 +197,16 @@ def advance_state(self, t): self.outputs.set(outputs) + + @compilable def reset(self): ## reset core components/statistics - # self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member in_restVals = jnp.zeros((self.batch_size, *self.area_shape)) ## input: (B | ix | iy) out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py) self.n_cells * self.patch_shape[0] * self.patch_shape[1])) self.inputs.set(in_restVals) self.outputs.set(out_restVals) - # Viet: NOTE: we should not need this function since the reset function - # one could set the batch size then do reset - # @compilable - # def batched_reset(self, batch_size): - # in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy) - - # out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py) - # self.n_cells * self.patch_shape[0] * self.patch_shape[1])) - - # self.inputs.set(in_restVals) - # self.outputs.set(out_restVals) - @classmethod def help(cls): ## component help function properties = { From f47549abae1779ac10d93f71d0185ca1e2f4bcd3 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 20 May 2026 21:51:19 -0400 Subject: [PATCH 13/23] mod to model_utils --- ngclearn/utils/model_utils.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index 9f22cc76..ce59a5ef 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -664,10 +664,34 @@ def softmax(x, tau=0.0): exp_x = jnp.exp(x - max_x) return exp_x / jnp.sum(exp_x, axis=1, keepdims=True) +@jit +def d_softmax(x): + """ + Derivative of the softmax function. + Note that this returns specifically the Jacobian tensor of softmax(x) w.r.t. + potential batch set of vectors (one per row). + + Args: + x: input (tensor) value (B x D) + + Returns: + output (tensor) derivative values (Jacobian with respect to input argument; B x D x D) + """ + ## caclulate softmax along feature dimension (axis=-1) + s = jax.nn.softmax(x, axis=-1) ## Shape: (B, D) + ## Batch-up diag(s); multiply s by 3D identity tensor + ## Shape: (B, D, 1) * (1, D, D) => (B, D, D) + diag_s = jnp.expand_dims(s, axis=-1) * jnp.eye(s.shape[-1]) + ## Batched outer(s, s): Broadcasted multiplication + ## Shape: (B, D, 1) * (B, 1, D) => (B, D, D) + outer_s = jnp.expand_dims(s, axis=-1) * jnp.expand_dims(s, axis=-2) + return diag_s - outer_s ## return full final Jacobian + @jit def threshold_soft(x, lmbda): """ - A soft threshold routine applied to each dimension of input + A soft threshold routine applied to each dimension of input. + (Note that this function does not contain a complementary derivative.) Args: x: data to apply threshold function over @@ -684,7 +708,8 @@ def threshold_soft(x, lmbda): @jit def threshold_cauchy(x, lmbda): """ - A Cauchy distributional threshold routine applied to each dimension of input + A Cauchy distributional threshold routine applied to each dimension of input. + (Note that this function does not contain a complementary derivative.) Args: x: data to apply threshold function over From 63f3f7923da59fb4fe0f9eff53b0484b44b156df Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 20 May 2026 22:32:22 -0400 Subject: [PATCH 14/23] mod to model_utils --- ngclearn/utils/model_utils.py | 77 ++++++++++++++++++++++------------- 1 file changed, 49 insertions(+), 28 deletions(-) diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index ce59a5ef..3cb09bd0 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -108,7 +108,7 @@ def create_function(fun_name, args=None): elif fun_name == "elu": fx = elu dfx = d_elu - elif fun_name == "silu": + elif fun_name == "silu": # NOTE: this is also the swish function fx = silu dfx = d_silu elif fun_name == "gelu": @@ -122,22 +122,20 @@ def create_function(fun_name, args=None): dfx = d_softplus elif fun_name == "softmax": fx = softmax - ## NOTE: below is an improper derivative proxy - ## correct dfx is a Jacobian of softmax (not currently supported!) - dfx = d_identity + dfx = d_softmax ## NOTE: this yields a Jacobian tensor Jx elif fun_name == "unit_threshold": fx = threshold ## default threshold is 1 (thus unit) dfx = d_threshold ## STE approximation elif "heaviside" in fun_name: fx = heaviside - dfx = d_heaviside ## STE approximation + dfx = d_heaviside ## NOTE: this is an STE approximation elif fun_name == "identity": fx = identity dfx = d_identity - else: + else: ## throw exception for un-supported activation raise RuntimeError( "Activation function (" + fun_name + ") is not recognized/supported!" - ) + ) return fx, dfx @partial(jit, static_argnums=[1]) @@ -214,7 +212,6 @@ def clamp_max(x, max_val): _x = x * mask + (1. - mask) * max_val return _x - @jit def one_hot(P): """ @@ -514,7 +511,7 @@ def d_softplus(x): output (tensor) derivative value (with respect to input argument) """ ## d/dx of softplus = logistic sigmoid - return nn.sigmoid(x) + return sigmoid(x) #nn.sigmoid(x) @jit def threshold(x, thr=1.): @@ -522,7 +519,7 @@ def threshold(x, thr=1.): @jit def d_threshold(x, thr=1.): - return x * 0. + 1. ## straight-thru estimator + return x * 0. + 1. ## NOTE: straight-thru estimator (STE) @jit def heaviside(x): @@ -530,15 +527,16 @@ def heaviside(x): @jit def d_heaviside(x): - return x * 0. + 1. ## straight-thru estimator + return x * 0. + 1. ## NOTE: straight-thru estimator (STE) @jit def sigmoid(x): - return nn.sigmoid(x) + sigm_x = 1./ (1. + jnp.exp(-x)) + return sigm_x #nn.sigmoid(x) @jit def d_sigmoid(x): - sigm_x = nn.sigmoid(x) ## pre-compute once + sigm_x = sigmoid(x) #nn.sigmoid(x) ## pre-compute once return sigm_x * (1. - sigm_x) def inverse_sigmoid(x, clip_bound=0.03): ## wrapper call for naming convention ease @@ -590,7 +588,9 @@ def d_swish(x, beta): @jit def silu(x): """ - Applies the sigmoid-weighted linear unit (SiLU or SiL) activation. + Applies the sigmoid-weighted linear unit (SiLU or SiL) activation. + Note that this is primarily a convenience wrapper function for + the `swish` activation. Args: x: data to transform via inverse logistic function @@ -607,7 +607,8 @@ def d_silu(x): @jit def gelu(x): """ - Applies the Gaussian Error Linear Unit (GeLU) activation (specifically, a fast approximation is used). + Applies the Gaussian Error Linear Unit (GeLU) activation + (specifically, a fast approximation is used via a weighted `swish`). Args: x: data to transform via inverse logistic function @@ -635,7 +636,7 @@ def elu(x, alpha=1.): Returns: output of the GeLU activation """ - mask = x >= 0. + mask = x >= 0. ## pre-compute mask return x * mask + ((jnp.exp(x) - 1) * alpha) * (1. - mask) @jit @@ -653,7 +654,8 @@ def softmax(x, tau=0.0): Args: x: a (N x D) input argument (pre-activity) to the softmax operator - tau: probability sharpening/softening factor + tau: probability sharpening/softening factor, if > 0.; else, <= 0 disables + this (Default: 0.) Returns: a (N x D) probability distribution output block @@ -664,28 +666,47 @@ def softmax(x, tau=0.0): exp_x = jnp.exp(x - max_x) return exp_x / jnp.sum(exp_x, axis=1, keepdims=True) -@jit -def d_softmax(x): +@partial(jit, static_argnums=[2]) +def d_softmax(x, tau=0., vmap_form=False): ## temperature-controlled softmax derivative co-routine """ Derivative of the softmax function. - Note that this returns specifically the Jacobian tensor of softmax(x) w.r.t. + Note that this returns specifically the Jacobian tensor `Jx` of softmax(x) w.r.t. potential batch set of vectors (one per row). Args: x: input (tensor) value (B x D) + vmap_form: optional algorithm switch flag; if True, `Jx` is computed using + Jax vmap (Default: False) + Returns: output (tensor) derivative values (Jacobian with respect to input argument; B x D x D) """ + _m = tau > 0. + _tau = tau * _m + (1. - _m) ## sets _tau=1 if tau <= 0 + Jx = 0. ## d_softmax(x)/d_x is a Jacobian matrix per sample ## caclulate softmax along feature dimension (axis=-1) - s = jax.nn.softmax(x, axis=-1) ## Shape: (B, D) - ## Batch-up diag(s); multiply s by 3D identity tensor - ## Shape: (B, D, 1) * (1, D, D) => (B, D, D) - diag_s = jnp.expand_dims(s, axis=-1) * jnp.eye(s.shape[-1]) - ## Batched outer(s, s): Broadcasted multiplication - ## Shape: (B, D, 1) * (B, 1, D) => (B, D, D) - outer_s = jnp.expand_dims(s, axis=-1) * jnp.expand_dims(s, axis=-2) - return diag_s - outer_s ## return full final Jacobian + s = softmax(x, tau=_tau) # nn.softmax(x, axis=-1) ## (BxD) + if not vmap_form: ### use pure tensorized batch-identity trick algorithm + diag_s = jnp.expand_dims(s, axis=-1) * jnp.eye(s.shape[-1]) ## Shape: (BxDx1) * (1xDxD) => (BxDxD) + ## batched outer(s, s) ~> outer product for each batch vector + outer_s = jnp.expand_dims(s, axis=-1) * jnp.expand_dims(s, axis=-2) ## (BxDx1) * (Bx1xD) => (BxDxD) + Jx = (diag_s - outer_s) * (1. / _tau) + else: ### switch to vmap algorithm + ## calc outer product using einsum (clean and readable) + outer_s = jnp.einsum('bi,bj->bij', s, s) ## (BxDxD) + ## fast batched diagonal insertion via a diagonal mask + d = s.shape[-1] + diag_indices = jnp.arange(d) + ## jax.at subtracts outer product from diagonal + ## (s - s^2) for diagonal, (-s_i s_j) for off-diagonal + ## avoids constructing a giant identity matrix + jacobian = -outer_s + ## vmap over index updates across batch + def add_diag(J_matrix, s_vector): + return J_matrix.at[diag_indices, diag_indices].add(s_vector) + Jx = ( jax.vmap(add_diag)(jacobian, s) ) * (1. / _tau) + return Jx ## return full, final Jacobian @jit def threshold_soft(x, lmbda): From 23c717105c5ad04f5ed7c29a28f5cfe426f19777 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 21 May 2026 00:39:35 -0400 Subject: [PATCH 15/23] mod to model_utils --- ngclearn/utils/model_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index 3cb09bd0..b83c94e2 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -660,8 +660,11 @@ def softmax(x, tau=0.0): Returns: a (N x D) probability distribution output block """ - if tau > 0.0: - x = x / tau + #if tau > 0.0: + # x = x / tau + _m = tau > 0. + _tau = tau * _m + (1. - _m) ## sets _tau=1 if tau <= 0 + x = x * (1./ _tau) max_x = jnp.max(x, axis=1, keepdims=True) exp_x = jnp.exp(x - max_x) return exp_x / jnp.sum(exp_x, axis=1, keepdims=True) From 631d65e7773d3f51941a63d25ddbac21e3c7ca05 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 21 May 2026 00:57:26 -0400 Subject: [PATCH 16/23] clean-up to model_utils --- ngclearn/utils/model_utils.py | 98 +++++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 11 deletions(-) diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index b83c94e2..aa8ef9cb 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -73,8 +73,7 @@ def create_function(fun_name, args=None): Args: fun_name: string name of activation function to produce; Currently supports: "tanh", "bkwta" (binary K-winners-take-all), "sigmoid", "relu", "lrelu", "relu6", - "elu", "silu", "gelu", "softplus", "softmax" (derivative not supported), "unit_threshold", "heaviside", - "identity" + "elu", "silu", "gelu", "softplus", "softmax", "unit_threshold", "heaviside", "identity" Returns: function fx, first derivative of function (w.r.t. input) dfx @@ -125,10 +124,10 @@ def create_function(fun_name, args=None): dfx = d_softmax ## NOTE: this yields a Jacobian tensor Jx elif fun_name == "unit_threshold": fx = threshold ## default threshold is 1 (thus unit) - dfx = d_threshold ## STE approximation + dfx = d_threshold ## NOTE: STE approximation elif "heaviside" in fun_name: fx = heaviside - dfx = d_heaviside ## NOTE: this is an STE approximation + dfx = d_heaviside ## NOTE: STE approximation elif fun_name == "identity": fx = identity dfx = d_identity @@ -139,7 +138,16 @@ def create_function(fun_name, args=None): return fx, dfx @partial(jit, static_argnums=[1]) -def bkwta(x, nWTA=5): #5 10 15 #K=50): +def bkwta(x, nWTA=5): ## binarized k-winner-take-all function + """ + The binarized K winner-take-all (K-WTA) function: + + Args: + x: input (tensor) value (real-valued) + + Returns: + output (tensor) value (binary values) + """ values, indices = lax.top_k(x, nWTA) # Note: we do not care to sort the indices kth = jnp.expand_dims(jnp.min(values,axis=1),axis=1) # must do comparison per sample in potential mini-batch topK = jnp.greater_equal(x, kth).astype(jnp.float32) # cast booleans to floats @@ -251,7 +259,7 @@ def chebyshev_norm(d, axis=-1, keepdims=False): @jit def binarize(data, threshold=0.5): """ - Converts the vector *data* to its binary equivalent + Converts the vector *data* to its binary equivalent. Args: data: the data to binarize (real-valued) @@ -354,10 +362,13 @@ def d_telu(x): ex = jnp.exp(x) tanh_ex = jnp.tanh(ex) return tanh_ex + x * ex * (1.0 - tanh_ex ** 2) + @jit def sine(x, omega_0=30): """ - f(x) = sin(x * omega_0). + The sine function, parameterized by frequency `omega`: + + | f(x) = sin(x * omega_0). Args: x: input (tensor) value @@ -370,14 +381,15 @@ def sine(x, omega_0=30): @jit def d_sine(x, omega_0=30): """ - frequency = omega_0 - frequency * cos(x * frequency). + The derivative of the sine function: + + | f'(x) = frequency * cos(x * frequency); where frequency = omega_0 Args: x: input (tensor) value Returns: - output (tensor) value + output (tensor) derivative value (with respect to input) """ return omega_0 * jnp.cos(omega_0 * x) @@ -489,7 +501,9 @@ def d_relu6(x): @jit def softplus(x): """ - The softplus elementwise function. + The softplus elementwise function: + + | f(x) = ln(1 + exp(-x)) Args: x: input (tensor) value @@ -515,27 +529,89 @@ def d_softplus(x): @jit def threshold(x, thr=1.): + """ + The threshold function (or Heaviside but with a non-zero boundary): + + | f(x) = 1 if x >= thr, otherwise 0 (for x < thr) + + Args: + x: input (tensor) value + + Returns: + output (tensor) value + """ return (x >= thr).astype(jnp.float32) @jit def d_threshold(x, thr=1.): + """ + Derivative of the threshold function; specifically, this employs the + straight-through estimator (STE) as a proxy/surrogate derivative instead. + + Args: + x: input (tensor) value + + Returns: + output (tensor) derivative value (with respect to input argument) + """ return x * 0. + 1. ## NOTE: straight-thru estimator (STE) @jit def heaviside(x): + """ + The Heaviside function: + + | f(x) = 1 if x >= 0, otherwise 0 (for x < 0) + + Args: + x: input (tensor) value + + Returns: + output (tensor) value + """ return (x >= 0.).astype(jnp.float32) @jit def d_heaviside(x): + """ + Derivative of the Heaviside function; specifically, this employs the + straight-through estimator (STE) as a proxy/surrogate derivative instead. + + Args: + x: input (tensor) value + + Returns: + output (tensor) derivative value (with respect to input argument) + """ return x * 0. + 1. ## NOTE: straight-thru estimator (STE) @jit def sigmoid(x): + """ + The sigmoid / logistic-link function: + + | f(x) = 1/(1 + exp(-x) + + Args: + x: input (tensor) value + + Returns: + output (tensor) value + """ sigm_x = 1./ (1. + jnp.exp(-x)) return sigm_x #nn.sigmoid(x) @jit def d_sigmoid(x): + """ + Derivative of the sigmoid / logistic-link function. + + Args: + x: input (tensor) value + + Returns: + output (tensor) derivative value (with respect to input argument) + """ sigm_x = sigmoid(x) #nn.sigmoid(x) ## pre-compute once return sigm_x * (1. - sigm_x) From 91ff5f1f03696210564e5857293c6d9a1a64b0c4 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 25 May 2026 02:18:39 -0400 Subject: [PATCH 17/23] reverted back rate-cell/gauss-cell to v3.0.0 states --- .../neurons/graded/gaussianErrorCell.py | 349 +++++----- .../components/neurons/graded/rateCell.py | 652 +++++++++--------- 2 files changed, 498 insertions(+), 503 deletions(-) diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index d12bd33f..9e072de5 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -1,175 +1,174 @@ -# %% - -from ngclearn.components.jaxComponent import JaxComponent -from jax import numpy as jnp, jit -from ngclearn import compilable #from ngcsimlib.parser import compilable -from ngclearn import Compartment #from ngcsimlib.compartment import Compartment - -class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell - """ - A simple (non-spiking) Gaussian error cell - this is a fixed-point calculation of a mismatch signal. Specifically, - this error cell offers a configurable variance and calculates its local free energy (Gaussian log likelihood). - - | --- Cell Input Compartments: --- - | mu - predicted value (takes in external signals) - | Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it's sigma^2 - | target - desired/goal value (takes in external signals) - | modulator - modulation signal (takes in optional external signals) - | mask - binary/gating mask to apply to error neuron calculations - | --- Cell Output Compartments: --- - | L - local loss function embodied by this cell - | dmu - derivative of L w.r.t. mu - | dSigma - derivative of L w.r.t. Sigma - | dtarget - derivative of L w.r.t. target - - Args: - name: the string name of this cell - - n_units: number of cellular entities (neural population size) - - batch_size: batch size dimension of this cell (Default: 1) - - sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution; - Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses - to a constant/fixed `sigma^2`, i.e., Sigma = sigma^2, where `sigma` is a scalar standard deviation argument - (Default: 1) - """ - def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs): - super().__init__(name, **kwargs) - - ## Layer Size Setup - _shape = (batch_size, n_units) ## default shape is 2D/matrix - if shape is None: - shape = (n_units,) ## we set shape to be equal to n_units if nothing provided - else: - _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor - sigma_shape = (1,1) - if not isinstance(sigma, float) and not isinstance(sigma, int): - sigma_shape = jnp.array(sigma).shape - self.sigma_shape = sigma_shape - self.shape = shape - self.batch_size = batch_size - self.n_units = n_units - - ## Convolution shape setup - self.width = self.height = n_units - - ## Compartment setup - restVals = jnp.zeros(_shape) - self.L = Compartment(0., display_name="Gaussian Log likelihood", units="nats") # loss compartment - self.mu = Compartment(restVals, display_name="Gaussian mean") # mean/mean name. input wire - self.dmu = Compartment(restVals) # derivative mean - _Sigma = jnp.zeros(sigma_shape) - self.Sigma = Compartment(_Sigma + sigma, display_name="Gaussian variance/covariance") - self.dSigma = Compartment(_Sigma) - self.target = Compartment(restVals, display_name="Gaussian data/target variable") # target. input wire - self.dtarget = Compartment(restVals) # derivative target - self.modulator = Compartment(restVals + 1.0) # to be set/consumed - self.mask = Compartment(restVals + 1.0) - - @staticmethod - def _eval_log_density(target, mu, Sigma): ## Gaussian log likelihood - ## NOTE: ln(p) = -(x - mu)^2 * 1/(2 Sigma), where Sigma might be sigma^2 or covariance matrix - _dmu = (target - mu) - #_numerator = 1. # 0.5 - log_density = -jnp.sum(jnp.square(_dmu)) * (1./((Sigma ** 2) * 2)) #* (_numerator / Sigma) - return log_density, _dmu ## return density and raw delta - - @compilable - def advance_state(self, dt): ## compute Gaussian error cell output (fixed-point) - # Get the variables - mu = self.mu.get() - target = self.target.get() - Sigma = self.Sigma.get() - modulator = self.modulator.get() - mask = self.mask.get() - - # Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit - # behavior of the local cost functional: - # FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2 - # but should support full log likelihood of the multivariate Gaussian with covariance of different types - # TODO: could introduce a variant of GaussianErrorCell that moves according to an ODE - # (using integration time constant dt) - - L, _dmu = GaussianErrorCell._eval_log_density(target, mu, Sigma) # L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) - ## _dmu => "raw" e (error unit/mis-match) # _dmu = (target - mu) - dmu = _dmu * (1./ Sigma) ## obtain precision-scaled e: (target - mu)/Sigma - dtarget = -dmu # reverse of e ## -(target - mu)/Sigma - dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for Sigma - - dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance... - dtarget = dtarget * modulator * mask - mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t - - # Update compartments - self.dmu.set(dmu) - self.dtarget.set(dtarget) - self.dSigma.set(dSigma) - self.L.set(jnp.squeeze(L)) - self.mask.set(mask) - - @compilable - def reset(self): ## reset core components/statistics - self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member - - @compilable - def batched_reset(self, batch_size): - _shape = (batch_size, self.shape[0]) - if len(self.shape) > 1: - _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) - restVals = jnp.zeros(_shape) - dmu = restVals - dtarget = restVals - dSigma = jnp.zeros(self.sigma_shape) - target = restVals - mu = restVals - modulator = mu + 1. - L = 0. #jnp.zeros((1, 1)) - mask = jnp.ones(_shape) - - self.dmu.set(dmu) - self.dtarget.set(dtarget) - self.dSigma.set(dSigma) - if not self.target.targeted: - self.target.set(target) - if not self.mu.targeted: - self.mu.set(mu) - self.modulator.set(modulator) - self.L.set(L) - self.mask.set(mask) - - @classmethod - def help(cls): ## component help function - properties = { - "cell_type": "GaussianErrorCell - computes mismatch/error signals at " - "each time step t (between a `target` and a prediction `mu`)" - } - compartment_props = { - "inputs": - {"mu": "External input prediction value(s)", - "Sigma": "External variance/covariance prediction value(s)", - "target": "External input target signal value(s)", - "modulator": "External input modulatory/scaling signal(s)", - "mask": "External binary/gating mask to apply to signals"}, - "outputs": - {"L": "Local loss / free-energy value embodied by this error-cell", - "dmu": "first derivative of loss w.r.t. prediction value(s)", - "dSigma": "first derivative of loss w.r.t. variance/covariance value(s)", - "dtarget": "first derivative of loss w.r.t. target value(s)"}, - } - hyperparams = { - "n_units": "Number of neuronal cells to model in this layer", - "batch_size": "Batch size dimension of this component", - "sigma": "External input variance value (currently fixed and not learnable)" - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "Gaussian(x=target; mu, sigma)", - "hyperparameters": hyperparams} - return info - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - X = GaussianErrorCell("X", 9) - print(X) +# %% + +from ngclearn.components.jaxComponent import JaxComponent +from jax import numpy as jnp, jit +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment + +class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell + """ + A simple (non-spiking) Gaussian error cell - this is a fixed-point calculation of a mismatch signal. Specifically, + this error cell offers a configurable variance and calculates its local free energy (Gaussian log likelihood). + + | --- Cell Input Compartments: --- + | mu - predicted value (takes in external signals) + | Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it's sigma^2 + | target - desired/goal value (takes in external signals) + | modulator - modulation signal (takes in optional external signals) + | mask - binary/gating mask to apply to error neuron calculations + | --- Cell Output Compartments: --- + | L - local loss function embodied by this cell + | dmu - derivative of L w.r.t. mu + | dSigma - derivative of L w.r.t. Sigma + | dtarget - derivative of L w.r.t. target + + Args: + name: the string name of this cell + + n_units: number of cellular entities (neural population size) + + batch_size: batch size dimension of this cell (Default: 1) + + sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution; + Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses + to a constant/fixed `sigma^2`, i.e., Sigma = sigma^2, where `sigma` is a scalar standard deviation argument + (Default: 1) + """ + def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs): + super().__init__(name, **kwargs) + + ## Layer Size Setup + _shape = (batch_size, n_units) ## default shape is 2D/matrix + if shape is None: + shape = (n_units,) ## we set shape to be equal to n_units if nothing provided + else: + _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor + sigma_shape = (1,1) + if not isinstance(sigma, float) and not isinstance(sigma, int): + sigma_shape = jnp.array(sigma).shape + self.sigma_shape = sigma_shape + self.shape = shape + self.n_units = n_units + self.batch_size = batch_size + + ## Convolution shape setup + self.width = self.height = n_units + + ## Compartment setup + restVals = jnp.zeros(_shape) + self.L = Compartment(0., display_name="Gaussian Log likelihood", units="nats") # loss compartment + self.mu = Compartment(restVals, display_name="Gaussian mean") # mean/mean name. input wire + self.dmu = Compartment(restVals) # derivative mean + _Sigma = jnp.zeros(sigma_shape) + self.Sigma = Compartment(_Sigma + sigma, display_name="Gaussian variance/covariance") + self.dSigma = Compartment(_Sigma) + self.target = Compartment(restVals, display_name="Gaussian data/target variable") # target. input wire + self.dtarget = Compartment(restVals) # derivative target + self.modulator = Compartment(restVals + 1.0) # to be set/consumed + self.mask = Compartment(restVals + 1.0) + + @staticmethod + def eval_log_density(target, mu, Sigma): + ## NOTE: ln(p) = -(x - mu)^2 * 1/(2 Sigma), where Sigma might be sigma^2 or covariance matrix + _dmu = (target - mu) + #_numerator = 1. # 0.5 + log_density = -jnp.sum(jnp.square(_dmu)) * (1./((Sigma ** 2) * 2)) #* (_numerator / Sigma) + return log_density, _dmu ## return density and raw delta + + @compilable + def advance_state(self, dt): ## compute Gaussian error cell output + # Get the variables + mu = self.mu.get() + target = self.target.get() + Sigma = self.Sigma.get() + modulator = self.modulator.get() + mask = self.mask.get() + + # Move Gaussian cell dynamics one step forward. Specifically, this + # routine emulates the error unit + ''' + ## This commented-out block of code should be adapted to replace the + ## five lines below it in future iterations (more accurate/flexible) + L, _dmu = GaussianErrorCell._eval_log_density(target, mu, Sigma) # L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) + ## _dmu => "raw" e (error unit/mis-match) # _dmu = (target - mu) + dmu = _dmu * (1./ Sigma) ## obtain precision-scaled e: (target - mu)/Sigma + dtarget = -dmu # reverse of e ## -(target - mu)/Sigma + dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for Sigma + ''' + _dmu = (target - mu) # e (error unit) + dmu = _dmu / Sigma + dtarget = -dmu # reverse of e + dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma + L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) + #L = GaussianErrorCell.eval_log_density(target, mu, Sigma) + + dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance... + dtarget = dtarget * modulator * mask + mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t + + # Update compartments + self.dmu.set(dmu) + self.dtarget.set(dtarget) + self.dSigma.set(dSigma) + self.L.set(jnp.squeeze(L)) + self.mask.set(mask) + + @compilable + def reset(self): ## reset core components/statistics + _shape = (self.batch_size, self.shape[0]) + if len(self.shape) > 1: + _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) + restVals = jnp.zeros(_shape) + dmu = restVals + dtarget = restVals + dSigma = jnp.zeros(self.sigma_shape) + target = restVals + mu = restVals + modulator = mu + 1. + L = 0. #jnp.zeros((1, 1)) + mask = jnp.ones(_shape) + + self.dmu.set(dmu) + self.dtarget.set(dtarget) + self.dSigma.set(dSigma) + self.target.set(target) + self.mu.set(mu) + self.modulator.set(modulator) + self.L.set(L) + self.mask.set(mask) + + @classmethod + def help(cls): ## component help function + properties = { + "cell_type": "GaussianErrorCell - computes mismatch/error signals at " + "each time step t (between a `target` and a prediction `mu`)" + } + compartment_props = { + "inputs": + {"mu": "External input prediction value(s)", + "Sigma": "External variance/covariance prediction value(s)", + "target": "External input target signal value(s)", + "modulator": "External input modulatory/scaling signal(s)", + "mask": "External binary/gating mask to apply to signals"}, + "outputs": + {"L": "Local loss / free-energy value embodied by this error-cell", + "dmu": "first derivative of loss w.r.t. prediction value(s)", + "dSigma": "first derivative of loss w.r.t. variance/covariance value(s)", + "dtarget": "first derivative of loss w.r.t. target value(s)"}, + } + hyperparams = { + "n_units": "Number of neuronal cells to model in this layer", + "batch_size": "Batch size dimension of this component", + "sigma": "External input variance value (currently fixed and not learnable)" + } + info = {cls.__name__: properties, + "compartments": compartment_props, + "dynamics": "Gaussian(x=target; mu, sigma)", + "hyperparameters": hyperparams} + return info + +if __name__ == '__main__': + from ngcsimlib.context import Context + with Context("Bar") as bar: + X = GaussianErrorCell("X", 9) + print(X) diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index 98f60325..3dcbdcc8 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -1,328 +1,324 @@ -# %% - -from jax import numpy as jnp, random, jit - -from ngclearn import compilable #from ngcsimlib.parser import compilable -from ngclearn import Compartment #from ngcsimlib.compartment import Compartment -from ngclearn.components.jaxComponent import JaxComponent -from ngclearn.utils.model_utils import create_function, threshold_soft, \ - threshold_cauchy -from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ - step_euler, step_rk2, step_rk4 -from ngcsimlib.logger import info - - -def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics - z_leak = jnp.sign(z) ## d/dx of Laplace is signum - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma): ## raw dynamics - z_leak = (z * 2)/(1. + jnp.square(z)) - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): ## raw dynamics - z_leak = jnp.exp(-jnp.square(z)) * z * 2 - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): ## raw dynamics - z_leak = z # * 2 ## Default: assume Gaussian - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -# @jit -def _modulate(j, dfx_val): - """ - Apply a signal modulator to j (typically of the form of a derivative/dampening function) - - Args: - j: current/stimulus value to modulate - - dfx_val: modulator signal - - Returns: - modulated j value - """ - return j * dfx_val - -# @partial(jit, static_argnames=["integType", "priorType"]) -def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0): - """ - Runs leaky rate-coded state dynamics one step in time. - - Args: - dt: integration time constant - - j: input (bottom-up) electrical/stimulus current - - j_td: modulatory (top-down) electrical/stimulus pressure - - z: current value of membrane/state - - tau_m: membrane/state time constant - - leak_gamma: strength of leak to apply to membrane/state - - integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4) - - priorType: scale-shift prior distribution to impose over neural dynamics - - Returns: - New value of membrane/state for next time step - """ - _dfz_fns = { - 0: lambda t, z, params: _dfz_internal_gaussian(z, *params), - 1: lambda t, z, params: _dfz_internal_laplace(z, *params), - 2: lambda t, z, params: _dfz_internal_cauchy(z, *params), - 3: lambda t, z, params: _dfz_internal_exp(z, *params), - } - _dfz_fn = _dfz_fns.get(priorType, _dfz_internal_gaussian) - _step_fns = { - 0: step_euler, - 1: step_rk2, - 2: step_rk4, - } - _step_fn = _step_fns.get(integType, step_euler) - params = (j, j_td, tau_m, leak_gamma) - _, _z = _step_fn(0., z, _dfz_fn, dt, params) - return _z - -# @jit -def _run_cell_stateless(j): - """ - A simplification of running a stateless set of dynamics over j (an identity - functional form of dynamics). - - Args: - j: stimulus to do nothing to - - Returns: - the stimulus - """ - return j + 0 - -class RateCell(JaxComponent): ## Rate-coded/real-valued cell - """ - A non-spiking cell driven by the gradient dynamics of neural generative - coding-driven predictive processing. - - The specific differential equation that characterizes this cell - is (for adjusting v, given current j, over time) is: - - | tau_m * dz/dt = lambda * prior(z) + (j + j_td) - | where j is the set of general incoming input signals (e.g., message-passed signals) - | and j_td is taken to be the set of top-down pressure signals - - | --- Cell Input Compartments: --- - | j - input pressure (takes in external signals) - | j_td - input/top-down pressure input (takes in external signals) - | --- Cell State Compartments --- - | z - rate activity - | --- Cell Output Compartments: --- - | zF - post-activation function activity, i.e., fx(z) - - Args: - name: the string name of this cell - - n_units: number of cellular entities (neural population size) - - tau_m: membrane/state time constant (milliseconds) - - prior: a kernel for specifying the type of centered scale-shift distribution - to impose over neuronal dynamics, applied to each neuron or - dimension within this component (Default: ("gaussian", 0)); this is - a tuple with 1st element containing a string name of the distribution - one wants to use while the second value is a `leak rate` scalar - that controls the influence/weighting that this distribution - has on the dynamics; for example, ("laplacian, 0.001") means that a - centered laplacian distribution scaled by `0.001` will be injected - into this cell's dynamics ODE each step of simulated time - - :Note: supported scale-shift distributions include "laplacian", - "cauchy", "exp", and "gaussian" - - act_fx: string name of activation function/nonlinearity to use - - output_scale: factor to multiply output of nonlinearity of this cell by (Default: 1.) - - integration_type: type of integration to use for this cell's dynamics; - current supported forms include "euler" (Euler/RK-1 integration) - and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") - - :Note: setting the integration type to the midpoint method will - increase the accuray of the estimate of the cell's evolution - at an increase in computational cost (and simulation time) - - resist_scale: a scaling factor applied to incoming pressure `j` (default: 1) - """ - - def __init__( - self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.), - integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): - jax_comp_kwargs = {k: v for k, v in kwargs.items() if k not in ('omega_0',)} - this_class_kwargs = {k: v for k, v in kwargs.items() if k in ('omega_0',)} - super().__init__(name, **jax_comp_kwargs) - - ## membrane parameter setup (affects ODE integration) - self.output_scale = output_scale - self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode - self.is_stateful = is_stateful - if isinstance(tau_m, float): - if tau_m <= 0: ## trigger stateless mode - self.is_stateful = False - priorType, leakRate = prior - priorTypeDict = { - "gaussian": 0, - "laplacian": 1, - "cauchy": 2, - "exp": 3 - } - self.priorType = priorTypeDict.get(priorType, 0) - self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior) - thresholdType, thr_lmbda = threshold - self.thresholdType = thresholdType ## type of thresholding function to use - self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics - self.resist_scale = resist_scale ## a "resistance" scaling factor - - ## integration properties - self.integrationType = integration_type - self.intgFlag = get_integrator_code(self.integrationType) - - ## Layer size setup - _shape = (batch_size, n_units) ## default shape is 2D/matrix - if shape is None: - shape = (n_units,) ## we set shape to be equal to n_units if nothing provided - else: - _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor - self.shape = shape - self.n_units = n_units - self.batch_size = batch_size - - omega_0 = None - if act_fx == "sine": - omega_0 = this_class_kwargs["omega_0"] - self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0) - - # compartments (state of the cell & parameters will be updated through stateless calls) - restVals = jnp.zeros(_shape) - self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current - self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity - self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure - self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity - - @compilable - def advance_state(self, dt): - # Get the compartment values - j = self.j.get() - j_td = self.j_td.get() - z = self.z.get() - - #if tau_m > 0.: - if self.is_stateful: - ### run a step of integration over neuronal dynamics - ## Notes: - ## self.pressure <-- "top-down" expectation / contextual pressure - ## self.current <-- "bottom-up" data-dependent signal - dfx_val = self.dfx(z) - j = _modulate(j, dfx_val) ## TODO: make this optional (for NGC circuit dynamics) - j = j * self.resist_scale - tmp_z = _run_cell( - dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag, - priorType=self.priorType - ) - ## apply optional thresholding sub-dynamics - if self.thresholdType == "soft_threshold": - tmp_z = threshold_soft(tmp_z, self.thr_lmbda) - elif self.thresholdType == "cauchy_threshold": - tmp_z = threshold_cauchy(tmp_z, self.thr_lmbda) - z = tmp_z ## pre-activation function value(s) - zF = self.fx(z) * self.output_scale ## post-activation function value(s) - else: - ## run in "stateless" mode (when no membrane time constant provided) - j_total = j + j_td - z = _run_cell_stateless(j_total) - zF = self.fx(z) * self.output_scale - - # Update compartments - self.j.set(j) - self.j_td.set(j_td) - self.z.set(z) - self.zF.set(zF) - - @compilable - def reset(self): ## reset core components/statistics - self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member - - @compilable - def batched_reset(self, batch_size): - _shape = (batch_size, self.shape[0]) - if len(self.shape) > 1: - _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) - restVals = jnp.zeros(_shape) - self.j.set(restVals) - self.j_td.set(restVals) - self.z.set(restVals) - self.zF.set(restVals) - - # def save(self, directory, **kwargs): - # ## do a protected save of constants, depending on whether they are floats or arrays - # tau_m = (self.tau_m if isinstance(self.tau_m, float) - # else jnp.ones([[self.tau_m]])) - # priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float) - # else jnp.ones([[self.priorLeakRate]])) - # resist_scale = (self.resist_scale if isinstance(self.resist_scale, float) - # else jnp.ones([[self.resist_scale]])) - # - # file_name = directory + "/" + self.name + ".npz" - # jnp.savez(file_name, - # tau_m=tau_m, priorLeakRate=priorLeakRate, - # resist_scale=resist_scale) #, key=self.key.value) - # - # def load(self, directory, seeded=False, **kwargs): - # file_name = directory + "/" + self.name + ".npz" - # data = jnp.load(file_name) - # ## constants loaded in - # self.tau_m = data['tau_m'] - # self.priorLeakRate = data['priorLeakRate'] - # self.resist_scale = data['resist_scale'] - # #if seeded: - # # self.key.set(data['key']) - - @classmethod - def help(cls): ## component help function - properties = { - "cell_type": "RateCell - evolves neurons according to rate-coded/" - "continuous dynamics " - } - compartment_props = { - "inputs": - {"j": "External input stimulus value(s)", - "j_td": "External top-down input stimulus value(s); these get " - "multiplied by the derivative of f(x), i.e., df(x)"}, - "states": - {"z": "Update to rate-coded continuous dynamics; value at time t"}, - "outputs": - {"zF": "Nonlinearity/function applied to rate-coded dynamics; f(z)"}, - } - hyperparams = { - "n_units": "Number of neuronal cells to model in this layer", - "batch_size": "Batch size dimension of this component", - "tau_m": "Cell state/membrane time constant", - "prior": "What kind of kurtotic prior to place over neuronal dynamics?", - "act_fx": "Elementwise activation function to apply over cell state `z`", - "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?", - "integration_type": "Type of numerical integration to use for the cell dynamics", - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)", - "hyperparameters": hyperparams} - return info - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - X = RateCell("X", 9, 0.03) - print(X) +# %% + +from jax import numpy as jnp, random, jit + +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment +from ngclearn.components.jaxComponent import JaxComponent +from ngclearn.utils.model_utils import create_function, threshold_soft, \ + threshold_cauchy +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ + step_euler, step_rk2, step_rk4 +from ngcsimlib.logger import info + + +def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics + z_leak = jnp.sign(z) ## d/dx of Laplace is signum + dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) + return dz_dt + +def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma): ## raw dynamics + z_leak = (z * 2)/(1. + jnp.square(z)) + dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) + return dz_dt + +def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): ## raw dynamics + z_leak = jnp.exp(-jnp.square(z)) * z * 2 + dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) + return dz_dt + +def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): ## raw dynamics + z_leak = z # * 2 ## Default: assume Gaussian + dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) + return dz_dt + +# @jit +def _modulate(j, dfx_val): + """ + Apply a signal modulator to j (typically of the form of a derivative/dampening function) + + Args: + j: current/stimulus value to modulate + + dfx_val: modulator signal + + Returns: + modulated j value + """ + return j * dfx_val + +# @partial(jit, static_argnames=["integType", "priorType"]) +def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0): + """ + Runs leaky rate-coded state dynamics one step in time. + + Args: + dt: integration time constant + + j: input (bottom-up) electrical/stimulus current + + j_td: modulatory (top-down) electrical/stimulus pressure + + z: current value of membrane/state + + tau_m: membrane/state time constant + + leak_gamma: strength of leak to apply to membrane/state + + integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4) + + priorType: scale-shift prior distribution to impose over neural dynamics + + Returns: + New value of membrane/state for next time step + """ + _dfz_fns = { + 0: lambda t, z, params: _dfz_internal_gaussian(z, *params), + 1: lambda t, z, params: _dfz_internal_laplace(z, *params), + 2: lambda t, z, params: _dfz_internal_cauchy(z, *params), + 3: lambda t, z, params: _dfz_internal_exp(z, *params), + } + _dfz_fn = _dfz_fns.get(priorType, _dfz_internal_gaussian) + _step_fns = { + 0: step_euler, + 1: step_rk2, + 2: step_rk4, + } + _step_fn = _step_fns.get(integType, step_euler) + params = (j, j_td, tau_m, leak_gamma) + _, _z = _step_fn(0., z, _dfz_fn, dt, params) + return _z + +# @jit +def _run_cell_stateless(j): + """ + A simplification of running a stateless set of dynamics over j (an identity + functional form of dynamics). + + Args: + j: stimulus to do nothing to + + Returns: + the stimulus + """ + return j + 0 + +class RateCell(JaxComponent): ## Rate-coded/real-valued cell + """ + A non-spiking cell driven by the gradient dynamics of neural generative + coding-driven predictive processing. + + The specific differential equation that characterizes this cell + is (for adjusting v, given current j, over time) is: + + | tau_m * dz/dt = lambda * prior(z) + (j + j_td) + | where j is the set of general incoming input signals (e.g., message-passed signals) + | and j_td is taken to be the set of top-down pressure signals + + | --- Cell Input Compartments: --- + | j - input pressure (takes in external signals) + | j_td - input/top-down pressure input (takes in external signals) + | --- Cell State Compartments --- + | z - rate activity + | --- Cell Output Compartments: --- + | zF - post-activation function activity, i.e., fx(z) + + Args: + name: the string name of this cell + + n_units: number of cellular entities (neural population size) + + tau_m: membrane/state time constant (milliseconds) + + prior: a kernel for specifying the type of centered scale-shift distribution + to impose over neuronal dynamics, applied to each neuron or + dimension within this component (Default: ("gaussian", 0)); this is + a tuple with 1st element containing a string name of the distribution + one wants to use while the second value is a `leak rate` scalar + that controls the influence/weighting that this distribution + has on the dynamics; for example, ("laplacian, 0.001") means that a + centered laplacian distribution scaled by `0.001` will be injected + into this cell's dynamics ODE each step of simulated time + + :Note: supported scale-shift distributions include "laplacian", + "cauchy", "exp", and "gaussian" + + act_fx: string name of activation function/nonlinearity to use + + output_scale: factor to multiply output of nonlinearity of this cell by (Default: 1.) + + integration_type: type of integration to use for this cell's dynamics; + current supported forms include "euler" (Euler/RK-1 integration) + and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") + + :Note: setting the integration type to the midpoint method will + increase the accuray of the estimate of the cell's evolution + at an increase in computational cost (and simulation time) + + resist_scale: a scaling factor applied to incoming pressure `j` (default: 1) + """ + + def __init__( + self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.), + integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): + jax_comp_kwargs = {k: v for k, v in kwargs.items() if k not in ('omega_0',)} + this_class_kwargs = {k: v for k, v in kwargs.items() if k in ('omega_0',)} + super().__init__(name, **jax_comp_kwargs) + + ## membrane parameter setup (affects ODE integration) + self.output_scale = output_scale + self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode + self.is_stateful = is_stateful + if isinstance(tau_m, float): + if tau_m <= 0: ## trigger stateless mode + self.is_stateful = False + priorType, leakRate = prior + priorTypeDict = { + "gaussian": 0, + "laplacian": 1, + "cauchy": 2, + "exp": 3 + } + self.priorType = priorTypeDict.get(priorType, 0) + self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior) + thresholdType, thr_lmbda = threshold + self.thresholdType = thresholdType ## type of thresholding function to use + self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics + self.resist_scale = resist_scale ## a "resistance" scaling factor + + ## integration properties + self.integrationType = integration_type + self.intgFlag = get_integrator_code(self.integrationType) + + ## Layer size setup + _shape = (batch_size, n_units) ## default shape is 2D/matrix + if shape is None: + shape = (n_units,) ## we set shape to be equal to n_units if nothing provided + else: + _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor + self.shape = shape + self.n_units = n_units + self.batch_size = batch_size + + omega_0 = None + if act_fx == "sine": + omega_0 = this_class_kwargs["omega_0"] + self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0) + + # compartments (state of the cell & parameters will be updated through stateless calls) + restVals = jnp.zeros(_shape) + self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current + self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity + self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure + self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity + + @compilable + def advance_state(self, dt): + # Get the compartment values + j = self.j.get() + j_td = self.j_td.get() + z = self.z.get() + + #if tau_m > 0.: + if self.is_stateful: + ### run a step of integration over neuronal dynamics + ## Notes: + ## self.pressure <-- "top-down" expectation / contextual pressure + ## self.current <-- "bottom-up" data-dependent signal + dfx_val = self.dfx(z) + j = _modulate(j, dfx_val) + j = j * self.resist_scale + tmp_z = _run_cell( + dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag, + priorType=self.priorType + ) + ## apply optional thresholding sub-dynamics + if self.thresholdType == "soft_threshold": + tmp_z = threshold_soft(tmp_z, self.thr_lmbda) + elif self.thresholdType == "cauchy_threshold": + tmp_z = threshold_cauchy(tmp_z, self.thr_lmbda) + z = tmp_z ## pre-activation function value(s) + zF = self.fx(z) * self.output_scale ## post-activation function value(s) + else: + ## run in "stateless" mode (when no membrane time constant provided) + j_total = j + j_td + z = _run_cell_stateless(j_total) + zF = self.fx(z) * self.output_scale + + # Update compartments + self.j.set(j) + self.j_td.set(j_td) + self.z.set(z) + self.zF.set(zF) + + @compilable + def reset(self): #, batch_size, shape): #n_units + _shape = (self.batch_size, self.shape[0]) + if len(self.shape) > 1: + _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) + restVals = jnp.zeros(_shape) + self.j.set(restVals) + self.j_td.set(restVals) + self.z.set(restVals) + self.zF.set(restVals) + + # def save(self, directory, **kwargs): + # ## do a protected save of constants, depending on whether they are floats or arrays + # tau_m = (self.tau_m if isinstance(self.tau_m, float) + # else jnp.ones([[self.tau_m]])) + # priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float) + # else jnp.ones([[self.priorLeakRate]])) + # resist_scale = (self.resist_scale if isinstance(self.resist_scale, float) + # else jnp.ones([[self.resist_scale]])) + # + # file_name = directory + "/" + self.name + ".npz" + # jnp.savez(file_name, + # tau_m=tau_m, priorLeakRate=priorLeakRate, + # resist_scale=resist_scale) #, key=self.key.value) + # + # def load(self, directory, seeded=False, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # ## constants loaded in + # self.tau_m = data['tau_m'] + # self.priorLeakRate = data['priorLeakRate'] + # self.resist_scale = data['resist_scale'] + # #if seeded: + # # self.key.set(data['key']) + + @classmethod + def help(cls): ## component help function + properties = { + "cell_type": "RateCell - evolves neurons according to rate-coded/" + "continuous dynamics " + } + compartment_props = { + "inputs": + {"j": "External input stimulus value(s)", + "j_td": "External top-down input stimulus value(s); these get " + "multiplied by the derivative of f(x), i.e., df(x)"}, + "states": + {"z": "Update to rate-coded continuous dynamics; value at time t"}, + "outputs": + {"zF": "Nonlinearity/function applied to rate-coded dynamics; f(z)"}, + } + hyperparams = { + "n_units": "Number of neuronal cells to model in this layer", + "batch_size": "Batch size dimension of this component", + "tau_m": "Cell state/membrane time constant", + "prior": "What kind of kurtotic prior to place over neuronal dynamics?", + "act_fx": "Elementwise activation function to apply over cell state `z`", + "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?", + "integration_type": "Type of numerical integration to use for the cell dynamics", + } + info = {cls.__name__: properties, + "compartments": compartment_props, + "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)", + "hyperparameters": hyperparams} + return info + +if __name__ == '__main__': + from ngcsimlib.context import Context + with Context("Bar") as bar: + X = RateCell("X", 9, 0.03) + print(X) From 172a68d1f9051031939c351450a20cd626aee2f3 Mon Sep 17 00:00:00 2001 From: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com> Date: Wed, 27 May 2026 00:21:18 -0400 Subject: [PATCH 18/23] Add effective_dim.py with participation_ratio function (#152) --- ngclearn/utils/analysis/effective_dim.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 ngclearn/utils/analysis/effective_dim.py diff --git a/ngclearn/utils/analysis/effective_dim.py b/ngclearn/utils/analysis/effective_dim.py new file mode 100644 index 00000000..e3c98a87 --- /dev/null +++ b/ngclearn/utils/analysis/effective_dim.py @@ -0,0 +1,11 @@ +from jax import numpy as jnp + +def participation_ratio(Z): + Zc = Z - Z.mean(axis=0, keepdims=True) + cov = (Zc.T @ Zc) / (Zc.shape[0] - 1) + + tr = jnp.trace(cov) + tr2_cov = tr * tr + cov2_tr = jnp.trace(cov @ cov) + + return tr2_cov / cov2_tr if cov2_tr > 0 else float("nan") From b8015fccd0131b357ec675264d28d83705d7020b Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 27 May 2026 00:24:14 -0400 Subject: [PATCH 19/23] minor tweak to eff-dim measure --- ngclearn/utils/analysis/effective_dim.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/ngclearn/utils/analysis/effective_dim.py b/ngclearn/utils/analysis/effective_dim.py index e3c98a87..a7b5aca0 100644 --- a/ngclearn/utils/analysis/effective_dim.py +++ b/ngclearn/utils/analysis/effective_dim.py @@ -1,6 +1,16 @@ from jax import numpy as jnp -def participation_ratio(Z): +def participation_ratio(latent_codes): + """ + Calculates the participation ratio coefficient for a set of latent codes + + Args: + latent_codes: a set of (N x D) latent code vectors (one row per vector code) + + Returns: + scalar measurement of the effective dimension + """ + Z = latent_codes Zc = Z - Z.mean(axis=0, keepdims=True) cov = (Zc.T @ Zc) / (Zc.shape[0] - 1) @@ -9,3 +19,4 @@ def participation_ratio(Z): cov2_tr = jnp.trace(cov @ cov) return tr2_cov / cov2_tr if cov2_tr > 0 else float("nan") + From 5647c686b6e400cb5ab783300613ea3b2314f07b Mon Sep 17 00:00:00 2001 From: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com> Date: Mon, 1 Jun 2026 11:31:49 -0400 Subject: [PATCH 20/23] Implement effective rank calculation in effective_dim.py (#153) Adds a function to calculate the effective rank of a code matrix using Shannon entropy. --- ngclearn/utils/analysis/effective_dim.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/ngclearn/utils/analysis/effective_dim.py b/ngclearn/utils/analysis/effective_dim.py index a7b5aca0..7b46f318 100644 --- a/ngclearn/utils/analysis/effective_dim.py +++ b/ngclearn/utils/analysis/effective_dim.py @@ -20,3 +20,27 @@ def participation_ratio(latent_codes): return tr2_cov / cov2_tr if cov2_tr > 0 else float("nan") + + + +def rankme(Z, eps=1e-7): + """ + Calculates the effective rank of for a code matrix Z + effective rank = exp(Shannon entropy), from Garrido, Balestriero, + Najman & LeCun, "RankMe: Assessing the Downstream Performance of Pretrained + Self-Supervised Representations by Their Rank" (ICML 2023, arXiv:2210.02885). + + Args: + latent_codes: a set of (N x D) latent code vectors (one row per vector code) + + Returns: + scalar measurement of the effective dimension + """ + + singular_values = jnp.linalg.svd(Z, compute_uv=False) ## singular values of Z + sum_singular_vals = jnp.sum(singular_values) ## L1 + if sum_singular_vals <= 0: + return float("nan") + p = singular_values / sum_singular_vals + eps ## L1-normalized singular value + shannon_entropy = -jnp.sum(p * jnp.log(p)) ## Shannon entropy + return jnp.exp(shannon_entropy) ## exp(Shannon entropy) = effective rank From df08087e234c7ac4b3f0016b508acd7165351f7f Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Thu, 4 Jun 2026 13:18:36 -0400 Subject: [PATCH 21/23] Update JaxProcessesMixin.py Added comments, and small naming update to the flag to locally store the state to avoid confusion with the global state. --- ngclearn/utils/JaxProcessesMixin.py | 50 +++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/ngclearn/utils/JaxProcessesMixin.py b/ngclearn/utils/JaxProcessesMixin.py index 3b849f09..3958dfb3 100644 --- a/ngclearn/utils/JaxProcessesMixin.py +++ b/ngclearn/utils/JaxProcessesMixin.py @@ -9,6 +9,10 @@ from ngcsimlib._src.process.baseProcess import BaseProcess class JaxCompiledMethod(CompiledMethod): + """ + A wrapper for a compiled method that includes jax's jit wrapped. Used + exclusively by the mixin and shouldn't be used elsewhere. + """ def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals): super().__init__(fn, fn_ast, auxiliary_ast, namespace, extra_globals) self._fn = jax.jit(fn) @@ -16,10 +20,19 @@ def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals): @property def source_fn(self): + """ + The source method not wrapped in jit + """ return self._fn_source @classmethod def wrap(cls, compiledMethod: CompiledMethod): + """ + Helper method to expand on a base compiled method + Args: + compiledMethod: The method to be expanded upon + Returns: the JaxCompiledMethod based on the input + """ return cls(compiledMethod._fn, compiledMethod.ast, compiledMethod.auxiliary_ast, @@ -28,7 +41,16 @@ def wrap(cls, compiledMethod: CompiledMethod): class JaxProcessesMixin: + """ + A mixin for the base Process that adds JAX functionality such as scan and + implicit jit wrapping + """ def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs): + """ + Look at the BaseProcess class for information about other arguments + Args: + use_jit: a flag for if the process should implicitly jit wrap + """ super().__init__(name, *args, **kwargs) self._previous_result = None self._previous_state = None @@ -36,27 +58,51 @@ def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs): @property def previous_result(self): + """ + Stores and returns the last result of scan (the second returned value) + """ return self._previous_result @property def previous_state(self): + """ + Stores and returns the last returned state of scan (the first returned + value) + """ return self._previous_state def clear(self): + """ + Clears out the previous result and state from scan + """ self._previous_result = None self._previous_state = None - def scan(self: "BaseProcess", inputs, current_state=None, save_state: bool = True, store_results: bool = True): + def scan(self: "BaseProcess", inputs, current_state=None, store_state: bool = True, store_results: bool = True): + """ + Runs the process through jax's scan method + Args: + inputs: The inputs for scan (use pack rows to generate), must be a jax array + current_state: Optional, the current state of the model, if none uses current global state + store_state: Optional flag, should the final state be stored in the process + store_results: Optional flag, should the final result be stored in the process + + Returns: the final state, the final result + + """ state = current_state or stateManager.state final_state, result = jax.lax.scan(self.run.compiled, state, inputs) - if save_state: + if store_state: self._previous_state = final_state if store_results: self._previous_result = result return final_state, result def compile(self: "baseProcess"): + """ + For use by the compiler + """ super().compile() if self._use_jit: self.run.compiled = JaxCompiledMethod.wrap(self.run.compiled) From 0ba704937785c59cb410a23ea33834af117854ca Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 9 Jun 2026 18:19:10 -0400 Subject: [PATCH 22/23] updates to utils, including integration of useful filters; included update to knn-probe to support cosine-distance --- ngclearn/components/__init__.py | 3 + .../components/input_encoders/__init__.py | 6 +- .../input_encoders/bernoulliCell.py | 7 +- .../components/input_encoders/poissonCell.py | 9 +- ngclearn/utils/analysis/knn_probe.py | 61 +++++++----- ngclearn/utils/filters/__init__.py | 2 + .../utils/filters/cortical_gauss_filter.py | 94 +++++++++++++++++++ ngclearn/utils/filters/gauss_filter.py | 80 +++++++++------- ngclearn/utils/metric_utils.py | 59 ++++++++++++ ngclearn/utils/viz/dim_reduce.py | 35 ++++++- 10 files changed, 291 insertions(+), 65 deletions(-) create mode 100644 ngclearn/utils/filters/cortical_gauss_filter.py diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index 0b81afdc..94a2c116 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -30,6 +30,9 @@ from .input_encoders.ganglionCell import RetinalGanglionCell from .input_encoders.latencyCell import LatencyCell from .input_encoders.phasorCell import PhasorCell +#from .input_encoders.populationCoderCell import PopulationCoderCell +#from .input_encoders.gridCell import GridCell +#from .input_encoders.placeCell import PlaceCell ## point to synapse component types from .synapses.denseSynapse import DenseSynapse diff --git a/ngclearn/components/input_encoders/__init__.py b/ngclearn/components/input_encoders/__init__.py index 9af3241b..bbee5280 100644 --- a/ngclearn/components/input_encoders/__init__.py +++ b/ngclearn/components/input_encoders/__init__.py @@ -1,7 +1,9 @@ from .bernoulliCell import BernoulliCell from .poissonCell import PoissonCell from .latencyCell import LatencyCell -from .ganglionCell import RetinalGanglionCell from .phasorCell import PhasorCell - +from .ganglionCell import RetinalGanglionCell +#from .populationCoderCell import PopulationCoderCell +#from .gridCell import GridCell +#from .placeCell import PlaceCell diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index 132f7d92..68e87539 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -27,7 +27,12 @@ class BernoulliCell(JaxComponent): """ def __init__( - self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs + self, + name: str, + n_units: int, + batch_size: int = 1, + key: Union[jax.Array, None] = None, + **kwargs ): super().__init__(name=name, key=key) diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py index 810776ab..333479ae 100644 --- a/ngclearn/components/input_encoders/poissonCell.py +++ b/ngclearn/components/input_encoders/poissonCell.py @@ -32,8 +32,13 @@ class PoissonCell(JaxComponent): @deprecate_args(max_freq="target_freq") def __init__( - self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1, - key: Union[jax.Array, None] = None, **kwargs + self, + name: str, + n_units: int, + target_freq: float = 63.75, + batch_size: int = 1, + key: Union[jax.Array, None] = None, + **kwargs ): super().__init__(name=name, key=key) diff --git a/ngclearn/utils/analysis/knn_probe.py b/ngclearn/utils/analysis/knn_probe.py index 54f061c9..b9656371 100644 --- a/ngclearn/utils/analysis/knn_probe.py +++ b/ngclearn/utils/analysis/knn_probe.py @@ -2,28 +2,34 @@ import numpy as np from ngcsimlib import deprecate_args from ngclearn.utils.analysis.probe import Probe -from ngclearn.utils.model_utils import kwta from jax import jit, random, numpy as jnp, lax, nn from functools import partial as bind -from ngclearn.utils.distribution_generator import DistributionGenerator - -@bind(jax.jit, static_argnums=[2, 3]) -def _run_knn_probe(_embeddings, Wx, K, dist_order=2): - ## Notes: - ### We do some 3D tensor math to handle a batch of predictions that need to be made - ### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories - _Wx = jnp.expand_dims(Wx, axis=0) ## 3D tensor format of KNN params (1 x N x D) - embed_tensor = jnp.expand_dims(_embeddings, axis=1) ## 3D projection of input signals (B x 1 x D) - D = embed_tensor - _Wx ## compute 3D batched delta tensor (B x N x D) - ## get batched (negative) distance measurements - dist = jnp.linalg.norm(D, ord=dist_order, axis=2, keepdims=True) ## (B x N x 1) - ## else, default -> euclidean - ### Note: negative distance allows us to find minimal points w/ maximal functions - dist = -jnp.squeeze(dist, axis=2) ## (B x N) - ## now get K winners per sample in batch + + +@bind(jax.jit, static_argnums=[2, 3, 4]) +def _run_knn_probe(_embeddings, Wx, K, dist_order=2, dist_metric="minkowski"): + if dist_metric == "cosine": + ## normalize the incoming batch embeddings along the feature axis (axis 1) + ### add a tiny epsilon to prevent division-by-zero errors + eps = 1e-12 + embed_norm = _embeddings / (jnp.linalg.norm(_embeddings, axis=1, keepdims=True) + eps) + ## normalize the internal memory database array (axis 1) + Wx_norm = Wx / (jnp.linalg.norm(Wx, axis=1, keepdims=True) + eps) + ## compute batched cosine similarity using standard 2D matrix multiplication + ### (B x D) @ (D x N) -> yields a (B x N) similarity matrix directly! + dist = jnp.matmul(embed_norm, Wx_norm.T) + else: # Default back to your Minkowski setup + _Wx = jnp.expand_dims(Wx, axis=0) ## (1 x N x D) + embed_tensor = jnp.expand_dims(_embeddings, axis=1) ## (B x 1 x D) + D = embed_tensor - _Wx ## (B x N x D) + dist = jnp.linalg.norm(D, ord=dist_order, axis=2, keepdims=True) ## (B x N x 1) + dist = -jnp.squeeze(dist, axis=2) ## (B x N) + + # lax.top_k naturally grabs the maximums (highest similarity or smallest negative distance) values, indices = lax.top_k(dist, K) return values, indices + class KNNProbe(Probe): """ This implements a K-nearest neighbors (KNN) probe, which is useful for evaluating the quality of @@ -82,16 +88,21 @@ def __init__( self.vote_fx = 0 ## 0 -> mode prediction; 1 -> mean prediction if vote_style == "mean": self.vote_fx = 1 + self.distance_function = distance_function - dist_fun, dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean - if "euclidean" in dist_fun.lower(): + dist_fun, dist_order = distance_function + self.dist_metric = "minkowski" # default tracker + if "cosine" in dist_fun.lower(): + self.dist_metric = "cosine" + dist_order = 2 ## fallback assignment + elif "euclidean" in dist_fun.lower(): dist_order = 2 elif "manhattan" in dist_fun.lower(): dist_order = 1 elif "chebyshev" in dist_fun.lower(): dist_order = jnp.inf - ## TODO: add in cosine-distance (and maybe Mahalanobis distance) self.dist_order = dist_order ## set distance order p + self.predictor_type = predictor_type self.pred_fx = 0 if "regressor" == predictor_type: @@ -102,14 +113,18 @@ def __init__( Wx = Wy = jnp.ones((1, 1)) ## Wy will be assumed to be one-hot encoded self.probe_params = (Wx, Wy) - def process(self, embeddings, dkey=None): ## TODO: JIT-i-fy this + def process(self, embeddings, dkey=None): _embeddings = embeddings if len(_embeddings.shape) > 2: flat_dim = embeddings.shape[1] * embeddings.shape[2] _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) - Wx, Wy = self.probe_params ## pull out KNN parameters - values, indices = _run_knn_probe(_embeddings, Wx, self.K, self.dist_order) + Wx, Wy = self.probe_params + + # Pass the explicit metric string directly to the JIT-compiled loop + values, indices = _run_knn_probe( + _embeddings, Wx, self.K, self.dist_order, self.dist_metric + ) ## do K-neighbor voting scheme (find mode/frequency prediction) Y_counts = jnp.zeros((_embeddings.shape[0], Wy.shape[1])) diff --git a/ngclearn/utils/filters/__init__.py b/ngclearn/utils/filters/__init__.py index 0dc711b9..7f778d24 100644 --- a/ngclearn/utils/filters/__init__.py +++ b/ngclearn/utils/filters/__init__.py @@ -1 +1,3 @@ from .gauss_filter import gaussian_filter +from .cortical_gauss_filter import cortical_gaussian_filter + diff --git a/ngclearn/utils/filters/cortical_gauss_filter.py b/ngclearn/utils/filters/cortical_gauss_filter.py new file mode 100644 index 00000000..5a57eaa9 --- /dev/null +++ b/ngclearn/utils/filters/cortical_gauss_filter.py @@ -0,0 +1,94 @@ +""" +Support for a function/pure "cortical" Gaussian filter - this can be configured to facilitate a filter that +engages in a difference-of-Gaussians-like or ratio-of-Gaussians-like process +""" + +import jax.numpy as jnp +from jax import lax, jit +from functools import partial + +def _calc_gaussian_kernel_2D( + sigma: float, ## standard deviation of kernel + radius: int ## controls shape of kernel +) -> jnp.ndarray: ## internal co-routine for Gaussian kernel + ## create a normalized 2D Gaussian kernel with shape (1, 1, 2*radius+1, 2*radius+1) + x = jnp.arange(-radius, radius + 1) + xx, yy = jnp.meshgrid(x, x) + kernel = jnp.exp(-0.5 * (xx ** 2 + yy ** 2) / sigma ** 2) + kernel = kernel / jnp.sum(kernel) + return kernel[jnp.newaxis, jnp.newaxis, :, :] + +@partial(jit, static_argnums=[3, 4, 5]) +def cortical_gaussian_filter( + images: jnp.ndarray, ## expected input shape : (B, C, H, W) + sigma_center: float, ## center excitation + sigma_surround: float, ## surround inhibition + kernel_size: int, ## kernel radius + use_ratio:bool=False, ## triggers either RoG vs DoG mode + semi_sat_constant:float=0.1, ## sigma_h (controls contrast-gain) + excitation_exp:float=2.0, ## p-exponent (typically in range of 1.5 - 2.0) + inhibition_exp:float=2.0, ## q-exponent + edge_pad_mode:str='edge' +) -> jnp.ndarray: + """ + Applies a configurable rectified Gaussian filter (either difference-of-Gaussians, i.e., DoG, or ratio-of-Gaussians, + i.e., RoG) to a tensor batch of 2D images (each of CxHxW tensor shape/format). Note that this variant filter + means that DoG mode acts more as a (half-wave) rectified subtraction of two Gaussian kernels and RoG mode acts more + as simple form of divisive normalization. + + Args: + images: input image tensor of shape (B, C, H, W) + + sigma_center: standard deviation for narrow / center blur + + sigma_surround: standard deviation for wide / surround blur + + kernel_size: kernel radius (window size will be `2*radius + 1`) + + use_ratio: if True, this filter applies a ratio-of-Gaussians (RoG) filter (Default: False) + + semi_sat_constant: suppresses amplified micro-variations in sensory/image space + + excitation_exp: p-exponent (typically in range of 1.5 - 2.0) (Default: 2.0) + + inhibition_exp: q-exponent (typically in range of 1.5 - 2.0) (Default: 2.0) + + edge_pad_mode: type of image edge-clamping/padding to use, either "edge" or "reflect" (Default: "edge") + + Returns: + An output tensor of shape (B, C, H, W) + """ + ## set up edge-artifact padding correction + padding_config = ((0, 0), (0, 0), (kernel_size, kernel_size), (kernel_size, kernel_size)) + padded_x = jnp.pad(images, padding_config, mode=edge_pad_mode) + + ## set up Gaussian kernels + k1 = _calc_gaussian_kernel_2D(sigma_center, kernel_size) + k2 = _calc_gaussian_kernel_2D(sigma_surround, kernel_size) + + dn = lax.ConvDimensionNumbers( + lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3) + ) + num_channels = images.shape[1] + + ## apply standard spatial convolutions + blur_center = lax.conv_general_dilated( + padded_x, k1, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels + ) + blur_surround = lax.conv_general_dilated( + padded_x, k2, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels + ) + if use_ratio: + ## cortically-plausible divisive normalization (or a cortical ratio-of-Gaussians; RoG) + ### first, apply half-wave rectification + rectified_center = jnp.maximum(0.0, blur_center) + rectified_surround = jnp.maximum(0.0, blur_surround) + ## next, apply nonlinear response exponents + numerator = jnp.power(rectified_center, excitation_exp) + denominator = jnp.power(rectified_surround, inhibition_exp) + (semi_sat_constant ** 2) + output = numerator / denominator ## calculate ratio + else: ## cortically-plausible difference-of-Gaussians (cortical DoG) + ### this is modeled via rectified linear subtraction under an output threshold (0 in the case below) + output = jnp.maximum(0.0, blur_center - blur_surround) + return output + diff --git a/ngclearn/utils/filters/gauss_filter.py b/ngclearn/utils/filters/gauss_filter.py index dcca211b..ef592b52 100644 --- a/ngclearn/utils/filters/gauss_filter.py +++ b/ngclearn/utils/filters/gauss_filter.py @@ -1,77 +1,93 @@ +""" +Support for a function/pure Gaussian filter - this can be configured to facilicate +difference-of-Gaussians (DoG) or ratio-of-Gaussians (RoG). +""" + import jax.numpy as jnp from jax import lax, jit from functools import partial -def _calc_gaussian_kernel_2D( ## internal co-routine - sigma: float, - radius: int +def _calc_gaussian_kernel_2D( ## internal co-routine for Gaussian kernel + sigma: float, ## standard deviation of kernel + radius: int ## controls shape of kernel ) -> jnp.ndarray: - ## Generate a (normalized) 2D Gaussian kernel of shape: (1, 1, 2*radius+1, 2*radius+1) + ## create a normalized 2D Gaussian kernel with shape (1, 1, 2*radius+1, 2*radius+1) x = jnp.arange(-radius, radius + 1) xx, yy = jnp.meshgrid(x, x) kernel = jnp.exp(-0.5 * (xx ** 2 + yy ** 2) / sigma ** 2) kernel = kernel / jnp.sum(kernel) - # Reshape to (out_channels, in_channels, height, width) for lax.conv_general_dilated + ## reshape output to: (out_channels, in_channels, height, width) to support lax.conv_general_dilated w/in filter return kernel[jnp.newaxis, jnp.newaxis, :, :] @partial(jit, static_argnums=[3, 4]) -def gaussian_filter( +def gaussian_filter( ## core filter routine images: jnp.ndarray, ## input image batch sigma_center: float, ## sigma1 sigma_surround: float, ## sigma2 kernel_size : int, ## radius - use_ratio=False, ## if True, this becomes a ratio-of-Gaussians - edge_pad_mode="edge" ## "reflect" + use_ratio:bool=False, ## if True, this becomes a ratio-of-Gaussians + edge_pad_mode:str="edge" ## "reflect" ) -> jnp.ndarray: """ - Applies a difference-of-Gaussians filter to a batch of 2D images (of CxHxW tensor shape). + Applies a configurable Gaussian filter (either difference-of-Gaussians or ratio-of-Gaussians) to a tensor batch of + 2D images (of CxHxW tensor shape). Args: - images: Input image tensor of shape (B, C, H, W) + images: input image tensor of shape (B, C, H, W) + + sigma_center: standard deviation for narrow / center blur - sigma_center: Standard deviation for narrow / center blur + sigma_surround: standard deviation for wide / surround blur - sigma_surround: Standard deviation for wide / surround blur + kernel_size: kernel radius (window size will be `2*radius + 1`) - kernel_size: Kernel radius (window size will be `2*radius + 1`) + use_ratio: if True, this filter applies a ratio-of-Gaussians (RoG) filter (Default: False) - use_ratio: whether or not to use a ratio-of-Gaussians filter (Default: False) + edge_pad_mode: type of image edge-clamping/padding to use, either "edge" or "reflect" (Default: "edge") Returns: An output tensor of shape (B, C, H, W) """ - ## Pad spatial dimensions (H, W) using edge-clamping to remove artifacts - # Format for 4D (B, C, H, W): ((Before_B, After_B), (Before_C, After_C), (Before_H, After_H), (Before_W, After_W)) + ## pad spatial dimensions (H, W) using edge/reflect-clamping in order to remove artifacts + ### format for 4D tensor (B, C, H, W) => + ### result: ((Before_B, After_B), (Before_C, After_C), (Before_H, After_H), (Before_W, After_W)) padding_config = ((0, 0), (0, 0), (kernel_size, kernel_size), (kernel_size, kernel_size)) padded_x = jnp.pad(images, padding_config, mode=edge_pad_mode) - ## Construct two 2D Gaussian kernels + ## construct two 2D Gaussian kernels k1 = _calc_gaussian_kernel_2D(sigma_center, kernel_size) ## center kernel k2 = _calc_gaussian_kernel_2D(sigma_surround, kernel_size) ## surround kernel - ## Define dimension ordering for lax.conv ('NCHW' standard layout) + ## define dimension ordering for lax.conv ('NCHW' standard layout) dn = lax.ConvDimensionNumbers( lhs_spec=(0, 1, 2, 3), ## (batch, channel, height, width) rhs_spec=(0, 1, 2, 3), ## (out_channel, in_channel, height, width) out_spec=(0, 1, 2, 3) ## (batch, channel, height, width) ) + num_channels = images.shape[1] ## get channel count dynamically for independent channel-wise filtering - ## Extract channel count dynamically for independent channel-wise filtering - num_channels = images.shape[1] - - ## Perform spatial convolutions w/ 'VALID' padding on the edge-padded input - blur_center = lax.conv_general_dilated( - padded_x, k1, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels + ## below performs spatial convolutions w/ "VALID" padding on the edge-padded input + blur_center = lax.conv_general_dilated( ## center Gaussian + padded_x, + k1, + window_strides=(1, 1), + padding='VALID', + dimension_numbers=dn, + feature_group_count=num_channels ) - blur_surround = lax.conv_general_dilated( - padded_x, k2, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels + blur_surround = lax.conv_general_dilated( ## surround Gaussian + padded_x, + k2, + window_strides=(1, 1), + padding='VALID', + dimension_numbers=dn, + feature_group_count=num_channels ) - ## Perform final filter calculation - if use_ratio: + if use_ratio: ## apply ratio-of-Gaussians (RoG) eps = 1e-5 - output = blur_center / (blur_surround + eps) ## Compute kernel ratio - else: - output = blur_center - blur_surround ## Compute kernel difference - return output ## shape: (B, C, H, W) + output = blur_center / (blur_surround + eps) ## calculate kernel ratio + else: ## apply difference-of-Gaussians (DoG) + output = blur_center - blur_surround ## calculate kernel difference + return output ## final shape: (B, C, H, W) diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index 38100611..e30f3d9c 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -464,6 +464,65 @@ def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10 bce = jnp.mean(bce) return bce +@partial(jit, static_argnums=[1]) +def measure_hoyer_sparsity(codes: jnp.ndarray, preserve_batch: bool=False) -> float: + """ + Measures the Hoyer sparsity for a set of latent codes. + Hoyer sparsity lies in [0, 1], where a value of 0.0 indicates if something is dense and + a value of 1 indicates something is extremely sparse. + + Args: + codes: matrix (shape: N x D) of non-negative codes to measure + sparsity of (per row); D is flattened latent code size + + preserve_batch: if True, will return one score per sample in batch + (Default: False), otherwise, returns scalar mean score + + Returns: + an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise + """ + # Flatten everything past the batch dimension + x = jnp.reshape(codes, (codes.shape[0], -1)) + N = x.shape[1] + + l1 = jnp.sum(jnp.abs(x), axis=1) + l2 = jnp.sqrt(jnp.sum(jnp.square(x), axis=1) + 1e-8) # epsilon to avoid division by zero + + hoyer = (jnp.sqrt(N) - (l1 / l2)) / (jnp.sqrt(N) - 1.0) + if not preserve_batch: + hoyer = jnp.mean(hoyer) # calc average sparsity across set/batch + return hoyer + +@partial(jit, static_argnums=[1]) +def measure_excess_kurtosis(codes: jnp.ndarray, preserve_batch: bool=False) -> float: + """ + Measures the peak and heavy-tailedness of a set of neural activation codes. Note that + higher values (> 0) indicate sparse, localized 'high-burst' activations. + + Args: + codes: matrix (shape: N x D) of non-negative codes to measure + sparsity of (per row) + + preserve_batch: if True, will return one score per sample in batch + (Default: False), otherwise, returns scalar mean score + + Returns: + an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise + """ + x = jnp.reshape(codes, (codes.shape[0], -1)) + mean = jnp.mean(x, axis=1, keepdims=True) ## 1st moment + variance = jnp.var(x, axis=1, keepdims=True) ## 2nd moment + + ## 4th central moment divided by variance squared + fourth_moment = jnp.mean(jnp.power(x - mean, 4), axis=1, keepdims=True) + kurtosis = fourth_moment / (jnp.square(variance) + 1e-8) ## kurtosis of distribution + excess_kurtosis = kurtosis - 3.0 ## calc "excess kurtosis" by subtracting 3 + if not preserve_batch: + excess_kurtosis = jnp.mean(excess_kurtosis) ## calc avg excess-kurtosis over set/batch + return excess_kurtosis + + +### class conformity metrics ### @partial(jit, static_argnums=[2, 3]) def _compute_contingency_table( ## vectorized construction of contingency matrix diff --git a/ngclearn/utils/viz/dim_reduce.py b/ngclearn/utils/viz/dim_reduce.py index 3f32057d..3300feef 100755 --- a/ngclearn/utils/viz/dim_reduce.py +++ b/ngclearn/utils/viz/dim_reduce.py @@ -28,7 +28,12 @@ def extract_pca_latents(vectors): ## PCA mapping routine z_2D = vectors return z_2D -def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): ## tSNE mapping routine +def extract_tsne_latents( + vectors, + perplexity=30, + n_pca_comp=32, + batch_size=500 +): ## tSNE mapping routine """ Projects collection of K vectors (stored in a matrix) to a two-dimensional (2D) visualization space via the t-distributed stochastic neighbor embedding algorithm (t-SNE). This algorithm also uses PCA to produce an @@ -41,9 +46,9 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): perplexity: the perplexity control factor for t-SNE (Default: 30) n_pca_comp: number of PCA top components (sorted by eigen-values) to retain/extract before continuing - with t-SNE dimensionality reduction + with t-SNE dimensionality reduction (Default: 32) - batch_size: number of sampled embedding vectors to use per iteration of online internal PCA + batch_size: number of sampled embedding vectors to use per iteration of online internal PCA (Default: 500) Returns: a matrix (K x 2) of projected vectors (to 2D space) @@ -67,7 +72,15 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): z_2D = vectors return z_2D -def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., cmap=None): +def plot_latents( + code_vectors, + labels, + plot_fname="2Dcode_plot.jpg", + alpha=1., + cmap=None, + xaxis_title=None, + yaxis_title=None +): """ Produces a label-overlaid (label map to distinct colors) scatterplot for visualizing two-dimensional latent codes (produced by either PCA or t-SNE). @@ -84,6 +97,11 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., c alpha: alpha intensity level to present colors in scatterplot cmap: custom color-map to provide + + xaxis_title: string denoting title to place for X-axis (Default: None) + + yaxis_title: string denoting title to place for Y-axis (Default: None) + """ curr_backend = plt.rcParams["backend"] matplotlib.use('Agg') ## temporarily go in Agg plt backend for tsne plotting @@ -98,10 +116,17 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., c if _cmap is None: _cmap = default_cmap #print("> USING DEFAULT CMAP!") - plt.scatter(code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=_cmap, alpha=alpha) + plt.scatter( + code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=_cmap, alpha=alpha + ) colorbar = plt.colorbar() #colorbar.set_alpha(1) #plt.draw_all() + if xaxis_title is not None: + plt.xlabel(xaxis_title, fontsize=16, fontweight="bold") + if yaxis_title is not None: + plt.ylabel(yaxis_title, fontsize=16, fontweight="bold") + plt.grid() plt.savefig("{0}".format(plot_fname), dpi=300) plt.clf() From ad0cc85ea3f3a6b2377fc98f3e2c24a1a64ae03a Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 9 Jun 2026 18:27:42 -0400 Subject: [PATCH 23/23] nudge for ngclearn docs to v3.1.1 minor update --- README.md | 2 +- pyproject.toml | 2 +- requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 757b9862..760a4933 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ $ python install -e . **Version:**
-3.1.0 +3.1.1 Author: Alexander G. Ororbia II
diff --git a/pyproject.toml b/pyproject.toml index bd1c7194..3458b230 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" # using setuptool building engine [project] name = "ngclearn" -version = "3.1.0" +version = "3.1.1" description = "Simulation software for building and analyzing computational neuroscience models, brain-inspired computing systems, and NeuroAI agents." authors = [ {name = "Alexander Ororbia", email = "ago@cs.rit.edu"}, diff --git a/requirements.txt b/requirements.txt index 978bd1fa..d68c12aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ numpy>=1.26.4 scikit-learn>=1.6.1 scipy>=1.14.1 matplotlib>=3.9.4 -# patchify # patchify has issues with pip installation +# patchify ## note: patchify has issues with pip installation jax>=0.4.28 jaxlib>=0.4.28 ngcsimlib>=3.1.0