SpikeRNN is a PyTorch framework for constructing functional spiking recurrent neural networks from continuous rate models, based on the approach presented in this paper.
The SpikeRNN framework consists of two complementary packages with a modern task-based architecture:
- rate: Continuous-variable rate RNN package for training models on cognitive tasks
- spiking: Spiking RNN package for converting rate models to biologically realistic networks
The rate package provides:
- Implementation of rate-based RNNs with Dale's principle support
- Modular task classes (GoNogoTask, XORTask, ManteTask)
- TaskFactory for dynamic task creation
- Multiple cognitive task implementations (Go-NoGo, XOR, Mante)
The spiking package provides:
- Rate-to-spike conversion maintaining task performance
- Biologically realistic leaky integrate-and-fire (LIF) neurons (see the sketch after this list)
- Spiking task evaluation classes (GoNogoSpikingTask, XORSpikingTask, ManteSpikingTask)
- SpikingTaskFactory for task-based evaluation
- Scaling factor optimization for optimal conversion
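For readers unfamiliar with the leaky integrate-and-fire neurons mentioned above, the sketch below shows the textbook LIF update: the membrane potential decays toward rest, integrates input current, and emits a spike and resets once it crosses threshold. This is a conceptual illustration only; the parameter names and values are placeholders, not the spiking package's internal implementation.

```python
import numpy as np

# Conceptual LIF dynamics (illustrative parameters, not the package internals)
dt = 0.1          # time step (ms)
tau_m = 10.0      # membrane time constant (ms)
v_rest = -65.0    # resting potential (mV)
v_thresh = -50.0  # spike threshold (mV)
v_reset = -65.0   # reset potential (mV)

T = 1000                       # number of time steps
I = 20.0 * np.ones(T)          # constant input drive
v = np.full(T, v_rest)
spikes = np.zeros(T, dtype=bool)

for t in range(1, T):
    # Leaky integration: dv/dt = (-(v - v_rest) + I) / tau_m
    v[t] = v[t - 1] + dt * (-(v[t - 1] - v_rest) + I[t - 1]) / tau_m
    if v[t] >= v_thresh:       # threshold crossing -> spike and reset
        spikes[t] = True
        v[t] = v_reset

print(f"Firing rate: {spikes.sum() / (T * dt / 1000.0):.1f} Hz")
```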
The framework features a modular task-based design that separates cognitive tasks from models:
- Easy Extensibility: Add new tasks without modifying core model code
- Consistent Interface: All tasks follow the same abstract interface
- Factory Pattern: Dynamic task creation and discovery (see the sketch after this list)
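To make the factory-pattern item above concrete, here is a minimal, generic sketch of the registry idea. The class and names in it are hypothetical and are not the package's actual implementation; the real entry points are TaskFactory and SpikingTaskFactory, used below.

```python
# Generic factory-pattern sketch (hypothetical names, not the SpikeRNN source)
class MinimalTaskFactory:
    """Maps string keys to task classes so new tasks can be added
    without modifying core model code."""
    _registry = {}

    @classmethod
    def register_task(cls, name, task_cls):
        cls._registry[name] = task_cls

    @classmethod
    def create_task(cls, name, settings):
        if name not in cls._registry:
            raise ValueError(f"Unknown task '{name}'. Available: {list(cls._registry)}")
        return cls._registry[name](settings)


class DemoTask:
    """Stand-in task; real tasks implement the AbstractTask interface."""
    def __init__(self, settings):
        self.settings = settings


MinimalTaskFactory.register_task('demo', DemoTask)
task = MinimalTaskFactory.create_task('demo', {'T': 200})
```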
Install both rate and spiking packages:
git clone https://github.com/NuttidaLab/spikeRNN.git
cd spikeRNN
pip install -e .

After installation, you can import both packages:
from rate import FR_RNN_dale, set_gpu, create_default_config
from spiking import LIF_network_fnc, lambda_grid_search, evaluate_task
# Task-based architecture
from rate import TaskFactory
from spiking import SpikingTaskFactory
from rate.tasks import GoNogoTask, XORTask, ManteTask
from spiking.tasks import GoNogoSpikingTask, XORSpikingTask, ManteSpikingTask

from spikeRNN import TaskFactory
# Create task settings
settings = {
'T': 200, # Trial duration
'stim_on': 50, # Stimulus onset
'stim_dur': 25, # Stimulus duration
'DeltaT': 1 # Sampling rate
}
# Create a Go/NoGo task using the factory
task = TaskFactory.create_task('go_nogo', settings)
# Generate a complete trial (stimulus + target + label)
stimulus, target, label = task.simulate_trial()
print(f"Generated {task.__class__.__name__} trial with label: {label}")The framework provides 2 levels of evaluation:
from spiking import SpikingTaskFactory, evaluate_task
# Direct task evaluation (when you have a network instance, not necessarily trained)
spiking_task = SpikingTaskFactory.create_task('go_nogo')
performance = spiking_task.evaluate_performance(spiking_rnn, n_trials=100)
print(f"Accuracy: {performance['overall_accuracy']:.2f}")
# High-level interface (when you have model files with trained weights)
performance = evaluate_task(
task_name='go_nogo',
model_dir='models/go-nogo',
n_trials=100,
save_plots=True
)
# Command line interface (for scripts and automation)
# python -m spiking.eval_tasks --task go_nogo --model_dir models/go-nogo/

You can create custom rate-based tasks by subclassing AbstractTask:
from rate.tasks import AbstractTask
import numpy as np

class MyCustomTask(AbstractTask):
    def validate_settings(self):
        # Validate required settings (assumes the base class stores the
        # settings dict passed to the constructor as self.settings)
        for key in ('T', 'stim_on', 'stim_dur'):
            assert key in self.settings, f"Missing required setting: {key}"

    def generate_stimulus(self, trial_type=None, seed=False):
        # Illustrative custom stimulus: one input channel with a unit pulse
        T, on, dur = self.settings['T'], self.settings['stim_on'], self.settings['stim_dur']
        stimulus = np.zeros((1, T))
        stimulus[0, on:on + dur] = 1.0
        return stimulus, (trial_type or 'go')

    def generate_target(self, label, seed=False):
        # Illustrative custom target: hold 1.0 after stimulus offset on 'go' trials
        target = np.zeros(self.settings['T'])
        if label == 'go':
            target[self.settings['stim_on'] + self.settings['stim_dur']:] = 1.0
        return target

# Use your custom task
custom_task = MyCustomTask(settings)
stimulus, target, label = custom_task.simulate_trial()

You can create custom spiking evaluation tasks by subclassing AbstractSpikingTask:
from spiking.tasks import AbstractSpikingTask, SpikingTaskFactory
import numpy as np

class MyCustomSpikingTask(AbstractSpikingTask):
    def get_default_settings(self):
        return {'T': 200, 'custom_param': 1.0}

    def get_sample_trial_types(self):
        return ['type_a', 'type_b']  # For visualization

    def generate_stimulus(self, trial_type=None):
        # Illustrative stimulus: one channel whose pulse sign depends on trial type
        # (assumes the base class exposes the settings dict as self.settings)
        T = self.settings['T']
        stimulus = np.zeros((1, T))
        stimulus[0, 50:75] = 1.0 if trial_type == 'type_a' else -1.0
        return stimulus, (trial_type or 'type_a')

    def evaluate_performance(self, spiking_rnn, n_trials=100):
        # Multi-trial performance evaluation (placeholder result)
        return {'accuracy': 0.85, 'n_trials': n_trials}
# Register and use with evaluation system
SpikingTaskFactory.register_task('my_custom', MyCustomSpikingTask)
# Now works with eval_tasks.py
python -m spiking.eval_tasks --task my_custom --model_dir models/custom/

Requirements:
- Python >= 3.7
- PyTorch >= 1.8.0
- NumPy >= 1.16.4
- SciPy >= 1.3.1
- Matplotlib >= 3.0.0
- Train Rate RNN: Use the `rate` package to train continuous-variable RNNs on cognitive tasks
- Save as .mat: Export the trained model in MATLAB format with all required parameters
- Optimize Scaling: Use `lambda_grid_search()` to find optimal rate-to-spike conversion parameters
- Convert to Spiking: Use `LIF_network_fnc()` to convert to biologically realistic spiking networks
- Evaluate Performance: Compare spiking and rate network performance on tasks
- Analyze Dynamics: Use spike analysis tools to study neural dynamics
from rate import FR_RNN_dale, set_gpu
from rate import TaskFactory
# Set up device and network
device = set_gpu('0', 0.4)
# w_in / w_out: pre-initialized input and output weight matrices
net = FR_RNN_dale(200, 0.2, 0.2, w_in, som_N=0, w_dist='gaus',
                  gain=1.5, apply_dale=True, w_out=w_out, device=device)
# Use task-based API to generate data
settings = {'T': 200, 'stim_on': 50, 'stim_dur': 25, 'DeltaT': 1}
task = TaskFactory.create_task('go_nogo', settings)
u, target, label = task.simulate_trial()
# ... training code ...

from spiking import LIF_network_fnc, lambda_grid_search
from spiking.eval_tasks import evaluate_task
import numpy as np
# Optimize scaling factor
lambda_grid_search(
model_dir='models/go-nogo',
task_name='go-nogo',
n_trials=100,
scaling_factors=list(np.arange(25, 76, 5))
)
# Evaluate performance
performance = evaluate_task(
task_name='go_nogo',
model_dir='models/go-nogo/'
)

Go-NoGo is a binary decision task requiring response inhibition:
- Go trials: Respond to stimulus
- NoGo trials: Withhold response
- Tests impulse control and decision-making
XOR is a temporal exclusive-OR task requiring working memory:
- Two sequential binary inputs
- Output XOR result after delay
- Tests working memory and logic operations
Mante is a context-dependent sensory integration task:
- Multiple sensory modalities
- Context determines relevant modality
- Tests flexible cognitive control
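To run the XOR and Mante tasks through the same factory interface as the Go-NoGo example above, a sketch like the following should work. The factory keys ('xor', 'mante') and the reuse of the Go-NoGo settings fields are assumptions; check each task class for the settings it actually requires.

```python
from rate import TaskFactory

# Assumed factory keys and settings; the exact required fields may differ per task
settings = {'T': 200, 'stim_on': 50, 'stim_dur': 25, 'DeltaT': 1}

for task_name in ['xor', 'mante']:
    task = TaskFactory.create_task(task_name, settings)
    stimulus, target, label = task.simulate_trial()
    print(f"{task_name}: generated trial with label {label}")
```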
The rate package saves models in two formats:
- `.pth` files: PyTorch format for continued rate-model training
- `.mat` files: MATLAB format containing all parameters used for spiking conversion

The exported `.mat` files include (see the loading sketch after this list):
- Complete connectivity masks (`m`, `som_m`)
- Neuron type assignments (`inh`, `exc`)
- Time constant parameters (`taus`, `taus_gaus`)
- Initial weight states (`w0`)
- Network size and architecture (`N`)
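To sanity-check an exported model, you can load the `.mat` file with SciPy and inspect the saved parameters. The file path below is hypothetical, and the assumption that the parameter names above appear as top-level keys is illustrative; adjust both to your own export.

```python
from scipy.io import loadmat

# Hypothetical path; point this at your own exported model
model_data = loadmat('models/go-nogo/model.mat')

# Parameter names follow the list above; their layout inside the file is assumed
for key in ['m', 'som_m', 'inh', 'exc', 'taus', 'taus_gaus', 'w0', 'N']:
    if key in model_data:
        value = model_data[key]
        print(key, getattr(value, 'shape', value))
```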
If you use this framework in your research, please cite:
@article{kim2019neural,
title={Neural population dynamics underlying motor learning transfer},
author={Kim, T. D. and Lian, T. and Yang, G. R.},
journal={Neuron},
volume={103},
number={2},
pages={355--371},
year={2019},
publisher={Elsevier}
}

We welcome contributions! Please see Contributing to SpikeRNN for guidelines.
- Fork the repository
- Create a feature branch
- Make your changes
- Add tests for new functionality
- Submit a pull request
MIT License - see LICENCE file for details.
- Documentation: Read the Docs
- Issues: GitHub Issues
- Rate Package: rate/
- Spiking Package: spiking/