Getting started tutorial
In this introductory example, you will see how to use spikeinterface
to perform a full electrophysiology analysis.
We will first create some simulated data, and we will then perform some pre-processing, run a couple of spike sorting
algorithms, inspect and validate the results, export to Phy, and compare spike sorters.
import matplotlib.pyplot as plt
The spikeinterface module by itself imports only the spikeinterface.core submodule, which is not useful for end users on its own.
import spikeinterface
We need to import the different submodules one by one (preferred). There are 5 modules:
extractors: file IO
toolkit: processing toolkit for pre-, post-processing, validation, and automatic curation
sorters: Python wrappers of spike sorters
comparison: comparison of spike sorting output
widgets: visualization
import spikeinterface as si # import core only
import spikeinterface.extractors as se
import spikeinterface.toolkit as st
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
We can also import all submodules at once with spikeinterface.full, which internally imports core+extractors+toolkit+sorters+comparison+widgets+exporters.
This is useful for notebooks, but it is a heavier import because many more dependencies are imported internally (scipy/sklearn/networkx/matplotlib/h5py…)
import spikeinterface.full as si
First, let’s download a simulated dataset from the ‘https://gin.g-node.org/NeuralEnsemble/ephy_testing_data’ repo.
Then we can open it. Note that a MEArec simulated file contains both a “recording” and a “sorting” object.
local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording, sorting_true = se.read_mearec(local_path)
print(recording)
print(sorting_true)
Out:
[INFO] Cloning dataset to Dataset(/home/docs/spikeinterface_datasets/ephy_testing_data)
[INFO] Attempting to clone from https://gin.g-node.org/NeuralEnsemble/ephy_testing_data to /home/docs/spikeinterface_datasets/ephy_testing_data
[INFO] Start enumerating objects
[INFO] Start counting objects
[INFO] Start compressing objects
[INFO] Start receiving objects
[INFO] Start resolving deltas
[INFO] Completed clone attempts for Dataset(/home/docs/spikeinterface_datasets/ephy_testing_data)
MEArecRecordingExtractor: 32 channels - 1 segments - 32.0kHz - 10.000s
file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz
file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
recording is a RecordingExtractor object, which extracts information about channel ids, channel locations
(if present), the sampling frequency of the recording, and the extracellular traces. sorting_true is a
SortingExtractor object, which contains spike-sorting related information, including unit ids,
spike trains, etc. Since the data are simulated, sorting_true has ground-truth information of the spiking
activity of each unit.
Let’s use the widgets module to visualize the traces and the raster plots.
w_ts = sw.plot_timeseries(recording, time_range=(0, 5))
w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5))
This is how you retrieve info from a RecordingExtractor…
channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()
num_seg = recording.get_num_segments()
print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)
print('Number of segments:', num_seg)
Out:
Channel ids: ['1' '2' '3' '4' '5' '6' '7' '8' '9' '10' '11' '12' '13' '14' '15' '16'
'17' '18' '19' '20' '21' '22' '23' '24' '25' '26' '27' '28' '29' '30'
'31' '32']
Sampling frequency: 32000.0
Number of channels: 32
Number of segments: 1
…and a SortingExtractor
num_seg = sorting_true.get_num_segments()
unit_ids = sorting_true.get_unit_ids()
spike_train = sorting_true.get_unit_spike_train(unit_id=unit_ids[0])
print('Number of segments:', num_seg)
print('Unit ids:', unit_ids)
print('Spike train of first unit:', spike_train)
Out:
Number of segments: 1
Unit ids: ['#0' '#1' '#2' '#3' '#4' '#5' '#6' '#7' '#8' '#9']
Spike train of first unit: [ 5197 8413 13124 15420 15497 15668 16929 19607 55107 59060
60958 105193 105569 117082 119243 119326 122293 122877 132413 139498
147402 147682 148271 149857 165454 170569 174319 176237 183598 192278
201535 217193 219715 221226 222967 223897 225338 243206 243775 248754
253184 253308 265132 266197 266662 283149 284716 287592 304025 305286
310438 310775 318460]
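Spike trains are returned as sample indices (frames), not as times in seconds. As a minimal sketch, assuming the 32 kHz sampling frequency printed above and using only the first five values of the output, the conversion to seconds and to inter-spike intervals could look like this (the variable names are illustrative and not part of the spikeinterface API):

```python
# First five spike frames of unit '#0', copied from the output above.
spike_frames = [5197, 8413, 13124, 15420, 15497]
fs = 32000.0  # sampling frequency in Hz, as reported by the recording

# Frame index -> time in seconds.
spike_times_s = [frame / fs for frame in spike_frames]

# Inter-spike intervals in milliseconds between consecutive spikes.
isis_ms = [(b - a) / fs * 1000.0 for a, b in zip(spike_frames, spike_frames[1:])]

print(spike_times_s)
print(isis_ms)
```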
spikeinterface internally uses probeinterface to handle Probe and ProbeGroup objects.
So any probe in the probeinterface collections can be downloaded
and set to a Recording object.
In this case, the MEArec dataset already handles a Probe and we don’t need to set it.
probe = recording.get_probe()
print(probe)
from probeinterface.plotting import plot_probe
plot_probe(probe)
Out:
Probe - 32ch
(<matplotlib.collections.PolyCollection object at 0x7eff45c1ad30>, <matplotlib.collections.PolyCollection object at 0x7eff43b80310>)
Using the toolkit, you can perform preprocessing on the recordings.
Each pre-processing function also returns a RecordingExtractor,
which makes it easy to build pipelines. Here, we filter the recording and
apply common median reference (CMR).
All these preprocessing steps are “lazy”. The computation is done on demand when we call
recording.get_traces(…) or when we save the object to disk.
recording_cmr = recording
recording_f = st.bandpass_filter(recording, freq_min=300, freq_max=6000)
print(recording_f)
recording_cmr = st.common_reference(recording_f, reference='global', operator='median')
print(recording_cmr)
# this computes and saves the recording after applying the preprocessing chain
recording_preprocessed = recording_cmr.save(format='binary')
print(recording_preprocessed)
Out:
BandpassFilterRecording: 32 channels - 1 segments - 32.0kHz - 10.000s
CommonReferenceRecording: 32 channels - 1 segments - 32.0kHz - 10.000s
Use cache_folder=/tmp/spikeinterface_cache/tmp3k4swz26/93VJX6BU
write_binary_recording with n_jobs 1 chunk_size None
BinaryRecordingExtractor: 32 channels - 1 segments - 32.0kHz - 10.000s
file_paths: ['/tmp/spikeinterface_cache/tmp3k4swz26/93VJX6BU/traces_cached_seg0.raw']
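To make the common median reference step more concrete, here is a small conceptual sketch in plain Python. It assumes that a global median reference means subtracting, at every sample, the median taken across all channels. The function name and the toy data are hypothetical, and this is not how the toolkit implements the operation internally:

```python
from statistics import median


def common_median_reference(traces):
    """Subtract, at every sample, the median across channels.

    `traces` is a list of samples; each sample holds one value per channel.
    """
    referenced = []
    for sample in traces:
        ref = median(sample)
        referenced.append([value - ref for value in sample])
    return referenced


# Toy data: 3 samples x 4 channels (values are made up for illustration).
toy_traces = [
    [10.0, 11.0, 12.0, 50.0],  # shared offset plus one large deflection
    [-5.0, -4.0, -6.0, -5.0],
    [0.0, 0.0, 0.0, 0.0],
]
toy_cmr = common_median_reference(toy_traces)
print(toy_cmr)
```

The median is used rather than the mean so that a large deflection on a single channel has little influence on the reference that is subtracted from all the others.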
Now you are ready to spike sort using the sorters module!
Let’s first check which sorters are implemented and which are installed:
print('Available sorters', ss.available_sorters())
print('Installed sorters', ss.installed_sorters())
Out:
Available sorters ['combinato', 'hdsort', 'herdingspikes', 'ironclust', 'kilosort', 'kilosort2', 'kilosort2_5', 'kilosort3', 'klusta', 'mountainsort4', 'pykilosort', 'spykingcircus', 'tridesclous', 'waveclus', 'yass']
Installed sorters ['herdingspikes', 'mountainsort4', 'tridesclous']
ss.installed_sorters() lists the sorters installed on the machine.
We can see that HerdingSpikes and Tridesclous are installed (along with Mountainsort4).
Spike sorters come with a set of parameters that users can change.
The default parameters are stored in dictionaries and can be accessed with:
print(ss.get_default_params('herdingspikes'))
print(ss.get_default_params('tridesclous'))
Out:
{'clustering_bandwidth': 5.5, 'clustering_alpha': 5.5, 'clustering_n_jobs': -1, 'clustering_bin_seeding': True, 'clustering_min_bin_freq': 16, 'clustering_subset': None, 'left_cutout_time': 0.3, 'right_cutout_time': 1.8, 'detect_threshold': 20, 'probe_masked_channels': [], 'probe_inner_radius': 70, 'probe_neighbor_radius': 90, 'probe_event_length': 0.26, 'probe_peak_jitter': 0.2, 't_inc': 100000, 'num_com_centers': 1, 'maa': 12, 'ahpthr': 11, 'out_file_name': 'HS2_detected', 'decay_filtering': False, 'save_all': False, 'amp_evaluation_time': 0.4, 'spk_evaluation_time': 1.0, 'pca_ncomponents': 2, 'pca_whiten': True, 'freq_min': 300.0, 'freq_max': 6000.0, 'filter': True, 'pre_scale': True, 'pre_scale_value': 20.0, 'filter_duplicates': True}
{'freq_min': 400.0, 'freq_max': 5000.0, 'detect_sign': -1, 'detect_threshold': 5, 'common_ref_removal': False, 'nested_params': None, 'total_memory': '500M', 'n_jobs_bin': 1}
Let’s run herdingspikes and change one of the parameters, say, the detect_threshold:
sorting_HS = ss.run_herdingspikes(recording=recording_preprocessed, detect_threshold=4)
print(sorting_HS)
Out:
Herdingspikes use the OLD spikeextractors with RecordingExtractorOldAPI
# Generating new position and neighbor files from data file
# Not Masking any Channels
# Sampling rate: 32000
# Localization On
# Number of recorded channels: 32
# Analysing frames: 320000; Seconds: 10.0
# Frames before spike in cutout: 10
# Frames after spike in cutout: 58
# tcuts: 42 90
# tInc: 100000
# Analysing frames from -42 to 100090 (0.0%)
# Analysing frames from 99958 to 200090 (31.2%)
# Analysing frames from 199958 to 300090 (62.5%)
# Analysing frames from 299958 to 320000 (93.8%)
# Detection completed, time taken: 0:00:00.714063
# Time per frame: 0:00:00.002231
# Time per sample: 0:00:00.000070
Loaded 836 spikes.
Fitting dimensionality reduction using all spikes...
...projecting...
...done
Clustering...
Clustering 836 spikes...
number of seeds: 13
seeds/job: 7
using 2 cpus
[Parallel(n_jobs=2)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=2)]: Done 2 out of 2 | elapsed: 1.6s finished
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.91.0/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:242: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
unique = np.ones(len(sorted_centers), dtype=np.bool)
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.91.0/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:255: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
labels = np.zeros(n_samples, dtype=np.int)
Number of estimated units: 6
HerdingspikesSortingExtractor: 6 units - 1 segments - 32.0kHz
file_path: /home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/checkouts/0.91.0/examples/getting_started/herdingspikes_output/HS2_sorted.hdf5
Alternatively, we can pass a full dictionary containing the parameters:
other_params = ss.get_default_params('herdingspikes')
other_params['detect_threshold'] = 5
# parameters set by params dictionary
sorting_HS_2 = ss.run_herdingspikes(recording=recording_preprocessed, output_folder="redringspikes_output2",
**other_params)
print(sorting_HS_2)
Out:
Herdingspikes use the OLD spikeextractors with RecordingExtractorOldAPI
# Generating new position and neighbor files from data file
# Not Masking any Channels
# Sampling rate: 32000
# Localization On
# Number of recorded channels: 32
# Analysing frames: 320000; Seconds: 10.0
# Frames before spike in cutout: 10
# Frames after spike in cutout: 58
# tcuts: 42 90
# tInc: 100000
# Analysing frames from -42 to 100090 (0.0%)
# Analysing frames from 99958 to 200090 (31.2%)
# Analysing frames from 199958 to 300090 (62.5%)
# Analysing frames from 299958 to 320000 (93.8%)
# Detection completed, time taken: 0:00:00.676818
# Time per frame: 0:00:00.002115
# Time per sample: 0:00:00.000066
Loaded 826 spikes.
Fitting dimensionality reduction using all spikes...
...projecting...
...done
Clustering...
Clustering 826 spikes...
number of seeds: 13
seeds/job: 7
using 2 cpus
[Parallel(n_jobs=2)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=2)]: Done 2 out of 2 | elapsed: 0.1s finished
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.91.0/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:242: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
unique = np.ones(len(sorted_centers), dtype=np.bool)
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.91.0/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:255: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
labels = np.zeros(n_samples, dtype=np.int)
Number of estimated units: 6
HerdingspikesSortingExtractor: 6 units - 1 segments - 32.0kHz
file_path: /home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/checkouts/0.91.0/examples/getting_started/redringspikes_output2/HS2_sorted.hdf5
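The two calling styles shown above (a single keyword argument versus an unpacked dictionary) reach the function in the same way. The following sketch illustrates this with a hypothetical stand-in function, run_sorter_stub, which is not part of spikeinterface and only echoes what it receives. The default values are also made up for the example:

```python
def run_sorter_stub(recording, output_folder=None, **sorter_params):
    """Hypothetical stand-in for a sorter entry point; it only echoes its inputs."""
    return {
        "recording": recording,
        "output_folder": output_folder,
        "params": sorter_params,
    }


# Hypothetical defaults; in the tutorial the real dictionary comes from
# ss.get_default_params('herdingspikes').
default_params = {"detect_threshold": 20, "filter": True}

# Copy before editing so the original defaults stay untouched.
other_params = dict(default_params)
other_params["detect_threshold"] = 5

via_keyword = run_sorter_stub("rec", detect_threshold=5, filter=True)
via_dict = run_sorter_stub("rec", **other_params)

print(via_keyword["params"] == via_dict["params"])
```

Passing a dictionary is convenient when several parameters change at once, or when the same settings need to be stored and reused across runs.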
Let’s run tridesclous as well, with default parameters:
sorting_TDC = ss.run_tridesclous(recording=recording_preprocessed)
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/checkouts/0.91.0/examples/getting_started/plot_getting_started.py", line 161, in <module>
sorting_TDC = ss.run_tridesclous(recording=recording_preprocessed)
File "/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.91.0/lib/python3.8/site-packages/spikeinterface/sorters/runsorter.py", line 320, in run_tridesclous
return run_sorter('tridesclous', *args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.91.0/lib/python3.8/site-packages/spikeinterface/sorters/runsorter.py", line 58, in run_sorter
sorting = run_sorter_local(sorter_name, recording, output_folder=output_folder,
File "/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.91.0/lib/python3.8/site-packages/spikeinterface/sorters/runsorter.py", line 83, in run_sorter_local
SorterClass.run_from_folder(output_folder, raise_error, verbose)
File "/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.91.0/lib/python3.8/site-packages/spikeinterface/sorters/basesorter.py", line 234, in run_from_folder
raise SpikeSortingError(
spikeinterface.sorters.utils.misc.SpikeSortingError: Spike sorting failed. You can inspect the runtime trace in spikeinterface_log.json
sorting_HS and sorting_TDC are SortingExtractor objects. We can print the units found using:
print('Units found by herdingspikes:', sorting_HS.get_unit_ids())
print('Units found by tridesclous:', sorting_TDC.get_unit_ids())
spikeinterface provides an efficient way to extract waveform snippets from paired recording/sorting objects.
The WaveformExtractor class samples some spikes (max_spikes_per_unit=500) for each cluster and stores
them on disk. These waveforms per cluster are helpful to compute the average waveform, or “template”, for each unit
and then to compute, for example, quality metrics.
we_TDC = si.WaveformExtractor.create(recording_preprocessed, sorting_TDC, 'waveforms', remove_if_exists=True)
we_TDC.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we_TDC.run(n_jobs=-1, chunk_size=30000)
print(we_TDC)
unit_id0 = sorting_TDC.unit_ids[0]
waveforms = we_TDC.get_waveforms(unit_id0)
print(waveforms.shape)
template = we_TDC.get_template(unit_id0)
print(template.shape)
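As a rough illustration of what waveform extraction involves, the sketch below cuts fixed-length snippets around spike frames on a single toy channel and averages them into a template. It is a simplified stand-in and not the WaveformExtractor implementation: the helper names are hypothetical, the trace is synthetic, and spikes too close to the edges are simply skipped here, which may differ from what the library does.

```python
def cut_waveforms(trace, spike_frames, nbefore, nafter):
    """Return one snippet per spike, skipping spikes too close to the edges."""
    snippets = []
    for frame in spike_frames:
        start, stop = frame - nbefore, frame + nafter
        if start < 0 or stop > len(trace):
            continue
        snippets.append(trace[start:stop])
    return snippets


def average_template(snippets):
    """Average the snippets sample by sample."""
    n = len(snippets)
    return [sum(values) / n for values in zip(*snippets)]


fs = 32000.0
nbefore = int(3.0 * fs / 1000.0)  # ms_before=3. expressed in samples
nafter = int(4.0 * fs / 1000.0)   # ms_after=4. expressed in samples

# Synthetic single-channel trace: flat, with a -10 deflection at each spike.
trace = [0.0] * 2000
toy_spike_frames = [50, 500, 1200, 1950]  # first and last are too close to the edges
for frame in toy_spike_frames:
    trace[frame] = -10.0

snippets = cut_waveforms(trace, toy_spike_frames, nbefore, nafter)
template = average_template(snippets)
print(len(snippets), len(template), template[nbefore])
```

With ms_before=3 and ms_after=4 at 32 kHz, each snippet spans 96 + 128 = 224 samples, and the spike itself sits at index nbefore within the snippet.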
Once we have the WaveformExtractor object,
we can post-process, validate, and curate the results. With
the toolkit.postprocessing submodule, one can, for example,
get waveforms, templates, maximum channels, PCA scores, or export the data
to Phy. Phy is a GUI for manual
curation of the spike sorting output. To export to Phy you can run:
from spikeinterface.exporters import export_to_phy
export_to_phy(we_TDC, './phy_folder_for_TDC',
compute_pc_features=False, compute_amplitudes=True)
Then you can run the template-gui with: phy template-gui phy_folder_for_TDC/params.py
and manually curate the results.
Quality metrics are very important for assessing spike sorting performance.
The spikeinterface.toolkit.qualitymetrics module implements several quality metrics
to assess the goodness of sorted units. Among those, for example,
are signal-to-noise ratio, ISI violation ratio, isolation distance, and many more.
These metrics are built on top of the WaveformExtractor class and return a dictionary with the unit ids as keys:
snrs = st.compute_snrs(we_TDC)
print(snrs)
isi_violations_rate, isi_violations_count = st.compute_isi_violations(we_TDC, isi_threshold_ms=1.5)
print(isi_violations_rate)
print(isi_violations_count)
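To give an intuition for the ISI violation count, the sketch below counts inter-spike intervals shorter than a threshold for one toy spike train. This is a simplified definition chosen for illustration: the function name and the toy frames are hypothetical, and the library's own computation (in particular how the violation rate is normalized) may differ.

```python
def count_isi_violations(spike_frames, fs, isi_threshold_ms=1.5):
    """Count consecutive spike pairs closer than isi_threshold_ms."""
    threshold_frames = isi_threshold_ms * fs / 1000.0
    frames = sorted(spike_frames)
    return sum(
        1 for a, b in zip(frames, frames[1:]) if (b - a) < threshold_frames
    )


# Toy spike train in frames at 32 kHz (made-up values).
toy_frames = [100, 130, 4000, 4040, 9000]
n_violations = count_isi_violations(toy_frames, fs=32000.0, isi_threshold_ms=1.5)
print(n_violations)
```

At 32 kHz a 1.5 ms threshold corresponds to 48 frames, so the intervals of 30 and 40 frames in the toy train count as violations while the two longer ones do not.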
All these quality metrics can be computed in one shot and returned as
a pandas.DataFrame:
metrics = st.compute_quality_metrics(we_TDC, metric_names=['snr', 'isi_violation', 'amplitude_cutoff'])
print(metrics)
Quality metrics can also be used to automatically curate the spike sorting output. For example, you can select sorted units with an SNR above a certain threshold and an ISI violation rate below another:
keep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_rate'] < 0.01)
print(keep_mask)
keep_unit_ids = keep_mask[keep_mask].index.values
print(keep_unit_ids)
curated_sorting = sorting_TDC.select_units(keep_unit_ids)
print(curated_sorting)
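The same threshold logic can be written without pandas, which may help to see what the boolean mask does. The metric values and unit names below are invented for the example, and the key names simply mirror the columns used in the code above:

```python
# Hypothetical per-unit metrics (values are made up for illustration).
toy_metrics = {
    "unit_a": {"snr": 12.0, "isi_violations_rate": 0.0},
    "unit_b": {"snr": 5.0, "isi_violations_rate": 0.0},
    "unit_c": {"snr": 9.0, "isi_violations_rate": 0.05},
}

# Keep a unit only if it passes both criteria.
toy_keep_unit_ids = [
    unit_id
    for unit_id, m in toy_metrics.items()
    if m["snr"] > 7.5 and m["isi_violations_rate"] < 0.01
]
print(toy_keep_unit_ids)
```

Here unit_b is rejected for its low SNR and unit_c for its ISI violation rate, so only unit_a is kept.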
The final part of this tutorial deals with comparing spike sorting outputs.
We can either (1) compare the spike sorting results with the ground-truth
sorting sorting_true, (2) compare the output of two sorters (HerdingSpikes
and Tridesclous), or (3) compare the output of multiple sorters:
comp_gt_TDC = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_TDC)
comp_TDC_HS = sc.compare_two_sorters(sorting1=sorting_TDC, sorting2=sorting_HS)
comp_multi = sc.compare_multiple_sorters(sorting_list=[sorting_TDC, sorting_HS],
name_list=['tdc', 'hs'])
When comparing with a ground-truth sorting extractor (1), you can get the sorting performance and plot a confusion matrix:
comp_gt_TDC.get_performance()
w_conf = sw.plot_confusion_matrix(comp_gt_TDC)
w_agr = sw.plot_agreement_matrix(comp_gt_TDC)
When comparing two sorters (2), we can see the matching of units between sorters. Units which are not matched have -1 as unit id:
comp_TDC_HS.hungarian_match_12
or the reverse:
comp_TDC_HS.hungarian_match_21
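The idea behind unit matching can be sketched in plain Python. The example below assumes an agreement score of the form matches / (n1 + n2 - matches), where two spikes match if they fall within a small tolerance in frames, and then searches for the one-to-one assignment with the highest total score. For simplicity it uses brute force over a tiny toy example rather than the Hungarian algorithm suggested by the attribute names. All function names, the tolerance, and the minimum score are assumptions made for illustration:

```python
from itertools import permutations


def count_matches(train1, train2, delta_frames):
    """Greedily pair spikes of two sorted trains that lie within delta_frames."""
    t1, t2 = sorted(train1), sorted(train2)
    i = j = matches = 0
    while i < len(t1) and j < len(t2):
        diff = t1[i] - t2[j]
        if abs(diff) <= delta_frames:
            matches += 1
            i += 1
            j += 1
        elif diff < 0:
            i += 1
        else:
            j += 1
    return matches


def agreement(train1, train2, delta_frames):
    """Assumed score: matched spikes over the union of both trains."""
    n = count_matches(train1, train2, delta_frames)
    return n / (len(train1) + len(train2) - n)


def best_assignment(units1, units2, delta_frames, min_score=0.5):
    """Brute-force one-to-one matching; unmatched units get -1."""
    ids1, ids2 = list(units1), list(units2)
    # Pad with None so that every unit of sorter 1 may stay unmatched.
    candidates = ids2 + [None] * len(ids1)
    best_total, best_match = -1.0, {}
    for perm in permutations(candidates, len(ids1)):
        total, match = 0.0, {}
        for u1, u2 in zip(ids1, perm):
            score = 0.0
            if u2 is not None:
                score = agreement(units1[u1], units2[u2], delta_frames)
            if score >= min_score:
                total += score
                match[u1] = u2
            else:
                match[u1] = -1
        if total > best_total:
            best_total, best_match = total, match
    return best_match


# Toy spike trains in frames (made-up values).
sorter1 = {"a": [100, 200, 300, 400], "b": [1000, 2000]}
sorter2 = {0: [101, 199, 302, 900], 1: [5000, 6000]}
toy_match_12 = best_assignment(sorter1, sorter2, delta_frames=5)
print(toy_match_12)
```

In this toy case unit "a" shares three of its four spikes with unit 0, while unit "b" has no counterpart and is therefore reported as -1.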
When comparing multiple sorters (3), you can extract a SortingExtractor object with units in agreement
between sorters. You can also plot a graph showing how the units are matched between the sorters.
sorting_agreement = comp_multi.get_agreement_sorting(minimum_agreement_count=2)
print('Units in agreement between Tridesclous and HerdingSpikes:', sorting_agreement.get_unit_ids())
w_multi = sw.plot_multicomp_graph(comp_multi)
plt.show()
Total running time of the script: ( 3 minutes 12.563 seconds)