Getting started tutorial

In this introductory example, you will see how to use spikeinterface to perform a full electrophysiology analysis. We will first create some simulated data, then perform some pre-processing, run a couple of spike sorting algorithms, inspect and validate the results, export to Phy, and compare spike sorters.

import matplotlib.pyplot as plt

By itself, the spikeinterface module imports only the spikeinterface.core submodule, which on its own is not very useful for end users.

import spikeinterface

The preferred approach is to import the submodules separately, one by one. There are 5 submodules:

  • extractors : file IO

  • toolkit : processing toolkit for pre-, post-processing, validation, and automatic curation

  • sorters : Python wrappers of spike sorters

  • comparison : comparison of spike sorting output

  • widgets : visualization

import spikeinterface as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.toolkit as st
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw

We can also import all submodules at once with spikeinterface.full, which internally imports core+extractors+toolkit+sorters+comparison+widgets+exporters.

This is convenient in notebooks, but it is a heavier import because many more dependencies are imported internally (scipy/sklearn/networkx/matplotlib/h5py…).

import spikeinterface.full as si

First, let’s download a simulated dataset from the ‘https://gin.g-node.org/NeuralEnsemble/ephy_testing_data’ repository.

Then we can open it. Note that the MEArec simulated file contains both a “recording” and a “sorting” object.

local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording, sorting_true = se.read_mearec(local_path)
print(recording)
print(sorting_true)

Out:

[INFO] Cloning dataset to Dataset(/home/docs/spikeinterface_datasets/ephy_testing_data)
[INFO] Attempting to clone from https://gin.g-node.org/NeuralEnsemble/ephy_testing_data to /home/docs/spikeinterface_datasets/ephy_testing_data
[INFO] Start enumerating objects
[INFO] Start counting objects
[INFO] Start compressing objects
[INFO] Start receiving objects
[INFO] Start resolving deltas
[INFO] Completed clone attempts for Dataset(/home/docs/spikeinterface_datasets/ephy_testing_data)
MEArecRecordingExtractor: 32 channels - 1 segments - 32.0kHz - 10.000s
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5

recording is a RecordingExtractor object, which extracts information about channel ids, channel locations (if present), the sampling frequency of the recording, and the extracellular traces. sorting_true is a SortingExtractor object, which contains spike-sorting related information, including unit ids, spike trains, etc. Since the data are simulated, sorting_true holds the ground-truth spiking activity of each unit.

Let’s use the widgets module to visualize the traces and the raster plots.

w_ts = sw.plot_timeseries(recording, time_range=(0, 5))
w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5))

This is how you retrieve information from a RecordingExtractor:

channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()
num_seg = recording.get_num_segments()

print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)
print('Number of segments:', num_seg)

Out:

Channel ids: ['1' '2' '3' '4' '5' '6' '7' '8' '9' '10' '11' '12' '13' '14' '15' '16'
 '17' '18' '19' '20' '21' '22' '23' '24' '25' '26' '27' '28' '29' '30'
 '31' '32']
Sampling frequency: 32000.0
Number of channels: 32
Number of segments: 1

…and from a SortingExtractor:

num_seg = sorting_true.get_num_segments()
unit_ids = sorting_true.get_unit_ids()
spike_train = sorting_true.get_unit_spike_train(unit_id=unit_ids[0])

print('Number of segments:', num_seg)
print('Unit ids:', unit_ids)
print('Spike train of first unit:', spike_train)

Out:

Number of segments: 1
Unit ids: ['#0' '#1' '#2' '#3' '#4' '#5' '#6' '#7' '#8' '#9']
Spike train of first unit: [  5197   8413  13124  15420  15497  15668  16929  19607  55107  59060
  60958 105193 105569 117082 119243 119326 122293 122877 132413 139498
 147402 147682 148271 149857 165454 170569 174319 176237 183598 192278
 201535 217193 219715 221226 222967 223897 225338 243206 243775 248754
 253184 253308 265132 266197 266662 283149 284716 287592 304025 305286
 310438 310775 318460]
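The printed spike train is given in sample indices, not seconds: all values are below 320000, the number of samples in 10 s at 32 kHz. As a minimal NumPy sketch, dividing by the sampling frequency converts the first few values above to seconds, and np.diff then gives the inter-spike intervals that ISI-based quality metrics rely on later in this tutorial.

```python
import numpy as np

# Sampling frequency and the first five spike-train samples printed above
fs = 32000.0
spike_train = np.array([5197, 8413, 13124, 15420, 15497])

# Convert sample indices to seconds
spike_times_s = spike_train / fs

# Inter-spike intervals in milliseconds
isis_ms = np.diff(spike_times_s) * 1000.0

print(spike_times_s)
print(isis_ms)
```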

spikeinterface internally uses probeinterface to handle Probe and ProbeGroup objects, so any probe in the probeinterface collections can be downloaded and attached to a Recording object. In this case, the MEArec dataset already contains a Probe, so we don’t need to set one.

probe = recording.get_probe()
print(probe)

from probeinterface.plotting import plot_probe

plot_probe(probe)

Out:

Probe - 32ch

(<matplotlib.collections.PolyCollection object at 0x7f5a9ddd6d00>, <matplotlib.collections.PolyCollection object at 0x7f5a9bd3c2e0>)

Using the toolkit, you can perform preprocessing on the recordings. Each preprocessing function returns a RecordingExtractor, which makes it easy to build pipelines. Here, we filter the recording and apply common median referencing (CMR). All these preprocessing steps are “lazy”: the computation is done on demand when we call recording.get_traces(…) or when we save the object to disk.

recording_f = st.bandpass_filter(recording, freq_min=300, freq_max=6000)
print(recording_f)
recording_cmr = st.common_reference(recording_f, reference='global', operator='median')
print(recording_cmr)

# this computes and saves the recording after applying the preprocessing chain
recording_preprocessed = recording_cmr.save(format='binary')
print(recording_preprocessed)

Out:

BandpassFilterRecording: 32 channels - 1 segments - 32.0kHz - 10.000s
CommonReferenceRecording: 32 channels - 1 segments - 32.0kHz - 10.000s
Use cache_folder=/tmp/spikeinterface_cache/tmptadbb2nn/Q2B8LWAU
write_binary_recording with n_jobs 1  chunk_size None
BinaryRecordingExtractor: 32 channels - 1 segments - 32.0kHz - 10.000s
  file_paths: ['/tmp/spikeinterface_cache/tmptadbb2nn/Q2B8LWAU/traces_cached_seg0.raw']
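To make the idea of lazy preprocessing concrete, here is a toy sketch in plain NumPy. ArrayRecording, LazyStep, FFTBandpass, and CommonMedianReference are hypothetical classes, not spikeinterface’s implementation; in particular, the FFT-mask band-pass is a simplified stand-in chosen to avoid a SciPy dependency. Each step only stores a reference to its parent, and the whole chain runs when get_traces() is called on the last step.

```python
import numpy as np


class ArrayRecording:
    """Toy in-memory recording: traces with shape (num_samples, num_channels)."""

    def __init__(self, traces, sampling_frequency):
        self._traces = np.asarray(traces, dtype=float)
        self.sampling_frequency = sampling_frequency

    def get_traces(self):
        return self._traces


class LazyStep:
    """Wraps a parent recording; work happens only when get_traces() is called."""

    def __init__(self, parent):
        self.parent = parent
        self.sampling_frequency = parent.sampling_frequency
        self.n_computations = 0  # how many times this step actually ran

    def apply(self, traces):
        raise NotImplementedError

    def get_traces(self):
        self.n_computations += 1
        return self.apply(self.parent.get_traces())


class FFTBandpass(LazyStep):
    """Illustrative FFT-mask band-pass (a swapped-in technique, not a Butterworth filter)."""

    def __init__(self, parent, freq_min, freq_max):
        super().__init__(parent)
        self.freq_min = freq_min
        self.freq_max = freq_max

    def apply(self, traces):
        spectrum = np.fft.rfft(traces, axis=0)
        freqs = np.fft.rfftfreq(traces.shape[0], d=1.0 / self.sampling_frequency)
        # Zero every frequency bin outside [freq_min, freq_max]
        spectrum[(freqs < self.freq_min) | (freqs > self.freq_max)] = 0
        return np.fft.irfft(spectrum, n=traces.shape[0], axis=0)


class CommonMedianReference(LazyStep):
    """Subtract, at every time sample, the median across channels."""

    def apply(self, traces):
        return traces - np.median(traces, axis=1, keepdims=True)


# Four channels sharing a 1 kHz component plus small independent noise
fs = 32000.0
t = np.arange(3200) / fs
rng = np.random.default_rng(0)
traces = np.sin(2 * np.pi * 1000.0 * t)[:, None] + 0.01 * rng.standard_normal((t.size, 4))

rec = ArrayRecording(traces, fs)
rec_f = FFTBandpass(rec, freq_min=300, freq_max=6000)
rec_cmr = CommonMedianReference(rec_f)
computed_before = rec_f.n_computations  # building the chain computes nothing
traces_cmr = rec_cmr.get_traces()       # the whole chain runs now
```

The 1 kHz component lies inside the pass band, so it survives filtering, but because it is shared by all channels the median reference removes it, leaving only the small per-channel noise.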

Now you are ready to spike sort using the sorters module! Let’s first check which sorters are implemented and which are installed:

print('Available sorters', ss.available_sorters())
print('Installed sorters', ss.installed_sorters())

Out:

Available sorters ['combinato', 'hdsort', 'herdingspikes', 'ironclust', 'kilosort', 'kilosort2', 'kilosort2_5', 'kilosort3', 'klusta', 'mountainsort4', 'pykilosort', 'spykingcircus', 'tridesclous', 'waveclus', 'yass']
Installed sorters ['herdingspikes', 'mountainsort4', 'tridesclous']

The ss.installed_sorters() function lists the sorters installed on the machine. We can see that HerdingSpikes, MountainSort4, and Tridesclous are installed. Spike sorters come with a set of parameters that users can change. The parameters are stored in a dictionary, and their default values can be accessed with:

print(ss.get_default_params('herdingspikes'))
print(ss.get_default_params('tridesclous'))

Out:

{'clustering_bandwidth': 5.5, 'clustering_alpha': 5.5, 'clustering_n_jobs': -1, 'clustering_bin_seeding': True, 'clustering_min_bin_freq': 16, 'clustering_subset': None, 'left_cutout_time': 0.3, 'right_cutout_time': 1.8, 'detect_threshold': 20, 'probe_masked_channels': [], 'probe_inner_radius': 70, 'probe_neighbor_radius': 90, 'probe_event_length': 0.26, 'probe_peak_jitter': 0.2, 't_inc': 100000, 'num_com_centers': 1, 'maa': 12, 'ahpthr': 11, 'out_file_name': 'HS2_detected', 'decay_filtering': False, 'save_all': False, 'amp_evaluation_time': 0.4, 'spk_evaluation_time': 1.0, 'pca_ncomponents': 2, 'pca_whiten': True, 'freq_min': 300.0, 'freq_max': 6000.0, 'filter': True, 'pre_scale': True, 'pre_scale_value': 20.0, 'filter_duplicates': True}
{'freq_min': 400.0, 'freq_max': 5000.0, 'detect_sign': -1, 'detect_threshold': 5, 'common_ref_removal': False, 'nested_params': None, 'total_memory': '500M', 'n_jobs_bin': 1}
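Keyword arguments passed to a run function override the sorter’s defaults. The sketch below illustrates that merging logic with the tridesclous defaults printed above. merge_params is a hypothetical helper, and rejecting unknown keys with a KeyError is an assumption made for illustration; the exact checks spikeinterface performs may differ.

```python
# Defaults copied from the tridesclous output printed above
default_params = {'freq_min': 400.0, 'freq_max': 5000.0, 'detect_sign': -1,
                  'detect_threshold': 5, 'common_ref_removal': False,
                  'nested_params': None, 'total_memory': '500M', 'n_jobs_bin': 1}


def merge_params(defaults, **overrides):
    """Hypothetical helper: copy the defaults and apply overrides,
    rejecting keys that the sorter does not define."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise KeyError(f"Unknown parameters: {sorted(unknown)}")
    params = dict(defaults)  # copy, so the defaults stay untouched
    params.update(overrides)
    return params


params = merge_params(default_params, detect_threshold=4)
print(params['detect_threshold'], default_params['detect_threshold'])

# A misspelled parameter name is caught instead of being silently ignored
rejected = None
try:
    merge_params(default_params, detect_treshold=4)
except KeyError as err:
    rejected = str(err)
```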

Let’s run herdingspikes and change one of the parameters, for example, the detect_threshold:

sorting_HS = ss.run_herdingspikes(recording=recording_preprocessed, detect_threshold=4)
print(sorting_HS)

Out:

Herdingspikes use the OLD spikeextractors with RecordingExtractorOldAPI
# Generating new position and neighbor files from data file
# Not Masking any Channels
# Sampling rate: 32000
# Localization On
# Number of recorded channels: 32
# Analysing frames: 320000; Seconds: 10.0
# Frames before spike in cutout: 10
# Frames after spike in cutout: 58
# tcuts: 42 90
# tInc: 100000
# Analysing frames from -42 to 100090  (0.0%)
# Analysing frames from 99958 to 200090  (31.2%)
# Analysing frames from 199958 to 300090  (62.5%)
# Analysing frames from 299958 to 320000  (93.8%)
# Detection completed, time taken: 0:00:00.721043
# Time per frame: 0:00:00.002253
# Time per sample: 0:00:00.000070
Loaded 836 spikes.
Fitting dimensionality reduction using all spikes...
...projecting...
...done
Clustering...
Clustering 836 spikes...
number of seeds: 13
seeds/job: 7
using 2 cpus
[Parallel(n_jobs=2)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=2)]: Done   2 out of   2 | elapsed:    1.5s finished
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/latest/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:242: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  unique = np.ones(len(sorted_centers), dtype=np.bool)
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/latest/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:255: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  labels = np.zeros(n_samples, dtype=np.int)
Number of estimated units: 6
HerdingspikesSortingExtractor: 6 units - 1 segments - 32.0kHz
  file_path: /home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/checkouts/latest/examples/getting_started/herdingspikes_output/HS2_sorted.hdf5

Alternatively, we can pass a full dictionary containing the parameters:

other_params = ss.get_default_params('herdingspikes')
other_params['detect_threshold'] = 5

# parameters set by params dictionary
sorting_HS_2 = ss.run_herdingspikes(recording=recording_preprocessed, output_folder="redringspikes_output2",
                                    **other_params)
print(sorting_HS_2)

Out:

Herdingspikes use the OLD spikeextractors with RecordingExtractorOldAPI
# Generating new position and neighbor files from data file
# Not Masking any Channels
# Sampling rate: 32000
# Localization On
# Number of recorded channels: 32
# Analysing frames: 320000; Seconds: 10.0
# Frames before spike in cutout: 10
# Frames after spike in cutout: 58
# tcuts: 42 90
# tInc: 100000
# Analysing frames from -42 to 100090  (0.0%)
# Analysing frames from 99958 to 200090  (31.2%)
# Analysing frames from 199958 to 300090  (62.5%)
# Analysing frames from 299958 to 320000  (93.8%)
# Detection completed, time taken: 0:00:00.681401
# Time per frame: 0:00:00.002129
# Time per sample: 0:00:00.000067
Loaded 826 spikes.
Fitting dimensionality reduction using all spikes...
...projecting...
...done
Clustering...
Clustering 826 spikes...
number of seeds: 13
seeds/job: 7
using 2 cpus
[Parallel(n_jobs=2)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=2)]: Done   2 out of   2 | elapsed:    0.1s finished
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/latest/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:242: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  unique = np.ones(len(sorted_centers), dtype=np.bool)
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/latest/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:255: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  labels = np.zeros(n_samples, dtype=np.int)
Number of estimated units: 6
HerdingspikesSortingExtractor: 6 units - 1 segments - 32.0kHz
  file_path: /home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/checkouts/latest/examples/getting_started/redringspikes_output2/HS2_sorted.hdf5

Let’s run tridesclous as well, with default parameters:

sorting_TDC = ss.run_tridesclous(recording=recording_preprocessed)

The sorting_HS and sorting_TDC are SortingExtractor objects. We can print the units found using:

print('Units found by herdingspikes:', sorting_HS.get_unit_ids())
print('Units found by tridesclous:', sorting_TDC.get_unit_ids())

spikeinterface provides an efficient way to extract waveform snippets from paired recording/sorting objects. The WaveformExtractor class samples some spikes (here max_spikes_per_unit=500) for each cluster and stores them on disk. These per-cluster waveforms are helpful to compute the average waveform, or “template”, for each unit and then, for example, to compute quality metrics.

we_TDC = si.WaveformExtractor.create(recording_preprocessed, sorting_TDC, 'waveforms', remove_if_exists=True)
we_TDC.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we_TDC.run(n_jobs=-1, chunk_size=30000)
print(we_TDC)

unit_id0 = sorting_TDC.unit_ids[0]
waveforms = we_TDC.get_waveforms(unit_id0)
print(waveforms.shape)

template = we_TDC.get_template(unit_id0)
print(template.shape)
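The core of waveform extraction can be sketched in NumPy: cut a window of ms_before/ms_after around each spike sample, randomly keep at most max_spikes_per_unit spikes, and average the snippets to get the template. extract_waveforms is a hypothetical function on synthetic data; the (spikes, samples, channels) layout is assumed to mirror what get_waveforms returns, while spikeinterface itself works chunk by chunk and stores results on disk.

```python
import numpy as np


def extract_waveforms(traces, spike_train, fs, ms_before=3.0, ms_after=4.0,
                      max_spikes_per_unit=500, seed=0):
    """Cut (n_spikes, n_samples, n_channels) snippets around spike samples."""
    nbefore = int(ms_before * fs / 1000.0)
    nafter = int(ms_after * fs / 1000.0)
    n_total = traces.shape[0]
    # Drop spikes whose window would run off either edge of the recording
    valid = spike_train[(spike_train >= nbefore) & (spike_train + nafter <= n_total)]
    # Randomly subsample to at most max_spikes_per_unit spikes
    rng = np.random.default_rng(seed)
    if valid.size > max_spikes_per_unit:
        valid = np.sort(rng.choice(valid, size=max_spikes_per_unit, replace=False))
    return np.stack([traces[s - nbefore:s + nafter] for s in valid])


# Synthetic 10 s, 4-channel recording at 32 kHz with 800 random spike times
fs = 32000.0
rng = np.random.default_rng(1)
traces = rng.standard_normal((320000, 4))
spike_train = np.sort(rng.choice(np.arange(320000), size=800, replace=False))

waveforms = extract_waveforms(traces, spike_train, fs)
template = waveforms.mean(axis=0)  # average waveform, i.e. the "template"
print(waveforms.shape, template.shape)
```

With ms_before=3 and ms_after=4 at 32 kHz, each snippet spans 96 + 128 = 224 samples.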

Once we have the WaveformExtractor object, we can post-process, validate, and curate the results. With the toolkit.postprocessing submodule, one can, for example, get waveforms, templates, maximum channels, and PCA scores. The exporters submodule can also export the data to Phy, a GUI for manual curation of the spike sorting output. To export to Phy you can run:

from spikeinterface.exporters import export_to_phy

export_to_phy(we_TDC, './phy_folder_for_TDC',
              compute_pc_features=False, compute_amplitudes=True)

Then you can run the template-gui with: phy template-gui phy_folder_for_TDC/params.py and manually curate the results.

Quality metrics for the spike sorting output are very important to assess the spike sorting performance. The spikeinterface.toolkit.qualitymetrics module implements several quality metrics to assess the goodness of sorted units. Among those are, for example, signal-to-noise ratio, ISI violation ratio, isolation distance, and many more. These metrics are built on top of the WaveformExtractor class and return a dictionary with the unit ids as keys:

snrs = st.compute_snrs(we_TDC)
print(snrs)
isi_violations_rate, isi_violations_count = st.compute_isi_violations(we_TDC, isi_threshold_ms=1.5)
print(isi_violations_rate)
print(isi_violations_count)
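A simplified sketch of the two metrics above, assuming SNR is the peak absolute template amplitude on the best channel divided by that channel’s noise level, with noise estimated by the median absolute deviation (MAD). The ISI part only counts intervals shorter than isi_threshold_ms and reports their fraction; spikeinterface’s isi_violations_rate uses a different normalization, so the numbers are not interchangeable. All function names here are hypothetical.

```python
import numpy as np


def noise_levels_mad(traces):
    """Robust per-channel noise: MAD scaled to a Gaussian standard deviation."""
    med = np.median(traces, axis=0)
    return np.median(np.abs(traces - med), axis=0) / 0.6744897501960817


def snr_from_template(template, noise_levels):
    """Peak |template| on its best channel divided by that channel's noise."""
    best_channel = np.argmax(np.abs(template).max(axis=0))
    return np.abs(template[:, best_channel]).max() / noise_levels[best_channel]


def isi_violations(spike_train, fs, isi_threshold_ms=1.5):
    """Simplified: number and fraction of ISIs shorter than the threshold."""
    isis_ms = np.diff(np.sort(spike_train)) / fs * 1000.0
    count = int(np.sum(isis_ms < isi_threshold_ms))
    fraction = count / max(len(isis_ms), 1)
    return fraction, count


rng = np.random.default_rng(2)
traces = rng.standard_normal((32000, 4)) * 10.0  # Gaussian noise with std 10
template = np.zeros((224, 4))
template[96, 2] = -80.0                          # negative peak of 80 on channel 2
snr = snr_from_template(template, noise_levels_mad(traces))

# 3232 - 3200 = 32 samples = 1 ms at 32 kHz, below the 1.5 ms threshold
spike_train = np.array([0, 3200, 3232, 6400, 9600])
fraction, count = isi_violations(spike_train, 32000.0)
print(snr, fraction, count)
```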

All these quality metrics can be computed in one shot and returned as a pandas.DataFrame:

metrics = st.compute_quality_metrics(we_TDC, metric_names=['snr', 'isi_violation', 'amplitude_cutoff'])
print(metrics)

Quality metrics can also be used to automatically curate the spike sorting output. For example, you can select sorted units with an SNR above a certain threshold and an ISI violation rate below another:

keep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_rate'] < 0.01)
print(keep_mask)

keep_unit_ids = keep_mask[keep_mask].index.values
print(keep_unit_ids)

curated_sorting = sorting_TDC.select_units(keep_unit_ids)
print(curated_sorting)
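The same selection rule can be written with plain NumPy arrays instead of a pandas DataFrame. The metric values below are made up for illustration and are not results from this tutorial; the point is that & combines the two boolean conditions element-wise, and indexing the unit ids with the mask keeps only the units passing both.

```python
import numpy as np

# Made-up metric values for five hypothetical units
unit_ids = np.array(['#0', '#1', '#2', '#3', '#4'])
snr = np.array([12.1, 5.3, 9.8, 20.4, 7.6])
isi_violations_rate = np.array([0.0, 0.0, 0.05, 0.002, 0.0])

# Same rule as above: SNR above 7.5 AND ISI violation rate below 0.01
keep_mask = (snr > 7.5) & (isi_violations_rate < 0.01)
keep_unit_ids = unit_ids[keep_mask]
print(keep_unit_ids)
```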

The final part of this tutorial deals with comparing spike sorting outputs. We can (1) compare the spike sorting results with the ground-truth sorting sorting_true, (2) compare the outputs of two sorters (HerdingSpikes and Tridesclous), or (3) compare the outputs of multiple sorters:

comp_gt_TDC = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_TDC)
comp_TDC_HS = sc.compare_two_sorters(sorting1=sorting_TDC, sorting2=sorting_HS)
comp_multi = sc.compare_multiple_sorters(sorting_list=[sorting_TDC, sorting_HS],
                                         name_list=['tdc', 'hs'])

When comparing with a ground-truth sorting extractor (1), you can get the sorting performance and plot a confusion matrix and an agreement matrix:

comp_gt_TDC.get_performance()
w_conf = sw.plot_confusion_matrix(comp_gt_TDC)
w_agr = sw.plot_agreement_matrix(comp_gt_TDC)
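A minimal sketch of how ground-truth performance can be computed for one pair of units: spikes are matched one-to-one within a tolerance window (0.4 ms here, an assumed value), then accuracy = tp / (tp + fn + fp), recall = tp / (tp + fn), and precision = tp / (tp + fp). The greedy two-pointer matching and the function names are illustrative assumptions, not the comparison module’s internals.

```python
import numpy as np


def count_matching_spikes(times1, times2, delta_frames):
    """Greedy one-to-one matching of two sorted spike trains within +/- delta_frames."""
    i = j = matches = 0
    while i < len(times1) and j < len(times2):
        diff = times2[j] - times1[i]
        if abs(diff) <= delta_frames:
            matches += 1
            i += 1
            j += 1
        elif diff < 0:
            j += 1  # tested spike is too early: move to the next one
        else:
            i += 1  # ground-truth spike has no partner: move on
    return matches


def performance(gt_train, tested_train, fs, delta_ms=0.4):
    delta_frames = int(delta_ms * fs / 1000.0)
    tp = count_matching_spikes(np.sort(gt_train), np.sort(tested_train), delta_frames)
    fn = len(gt_train) - tp      # ground-truth spikes that were missed
    fp = len(tested_train) - tp  # tested spikes with no ground-truth partner
    return {'accuracy': tp / (tp + fn + fp),
            'recall': tp / (tp + fn),
            'precision': tp / (tp + fp)}


gt = np.array([1000, 2000, 3000, 4000])
tested = np.array([1005, 2030, 3000, 5000])  # 2030 and 5000 fall outside the window
perf = performance(gt, tested, fs=32000.0)
print(perf)
```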

When comparing two sorters (2), we can see the matching of units between sorters. Units that are not matched have -1 as their unit id:

comp_TDC_HS.hungarian_match_12

or the reverse:

comp_TDC_HS.hungarian_match_21
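Hungarian matching pairs units so that the total agreement score is maximal, where agreement = matches / (n1 + n2 - matches). For a handful of units the same optimum can be found by brute force, as in this sketch. The tolerance of 12 samples and the minimum score of 0.5 are assumptions, and all names are hypothetical. For realistic unit counts, a linear-assignment solver such as scipy.optimize.linear_sum_assignment is the efficient choice.

```python
import itertools

import numpy as np


def count_matches(t1, t2, delta):
    """Greedy one-to-one matching of two sorted spike trains within +/- delta samples."""
    i = j = n = 0
    while i < len(t1) and j < len(t2):
        d = t2[j] - t1[i]
        if abs(d) <= delta:
            n, i, j = n + 1, i + 1, j + 1
        elif d < 0:
            j += 1
        else:
            i += 1
    return n


def agreement_matrix(sorting1, sorting2, delta):
    """Agreement = matches / (n1 + n2 - matches) for every unit pair."""
    ids1, ids2 = list(sorting1), list(sorting2)
    scores = np.zeros((len(ids1), len(ids2)))
    for a, u1 in enumerate(ids1):
        for b, u2 in enumerate(ids2):
            t1, t2 = np.sort(sorting1[u1]), np.sort(sorting2[u2])
            m = count_matches(t1, t2, delta)
            scores[a, b] = m / (len(t1) + len(t2) - m)
    return ids1, ids2, scores


def best_match_12(sorting1, sorting2, delta=12, min_score=0.5):
    """Brute-force optimal one-to-one assignment; unmatched units map to -1."""
    ids1, ids2, scores = agreement_matrix(sorting1, sorting2, delta)
    # None entries let units of sorting1 stay unassigned
    candidates = list(range(len(ids2))) + [None] * len(ids1)
    best_total, best_perm = -1.0, None
    for perm in itertools.permutations(candidates, len(ids1)):
        total = sum(scores[a, b] for a, b in enumerate(perm) if b is not None)
        if total > best_total:
            best_total, best_perm = total, perm
    match = {}
    for a, b in enumerate(best_perm):
        ok = b is not None and scores[a, b] >= min_score
        match[ids1[a]] = ids2[b] if ok else -1
    return match


# Two toy sortings: unit id -> spike samples
sorting_a = {'a0': np.array([100, 500, 900, 1300]),
             'a1': np.array([200, 600, 1000]),
             'a2': np.array([50, 5000])}
sorting_b = {'b0': np.array([202, 598, 1001]),
             'b1': np.array([101, 505, 905, 1300])}
match = best_match_12(sorting_a, sorting_b)
print(match)
```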

When comparing multiple sorters (3), you can extract a SortingExtractor object with units in agreement between sorters. You can also plot a graph showing how the units are matched between the sorters.

sorting_agreement = comp_multi.get_agreement_sorting(minimum_agreement_count=2)

print('Units in agreement between Tridesclous and HerdingSpikes:', sorting_agreement.get_unit_ids())

w_multi = sw.plot_multicomp_graph(comp_multi)

plt.show()
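Conceptually, multi-sorter agreement groups units that are matched to each other across sorters and keeps the groups found by at least minimum_agreement_count sorters. This union-find sketch assumes the pairwise matches are already known; the third sorter “ks” and all unit names are hypothetical, and spikeinterface’s graph-based implementation may resolve conflicting matches differently.

```python
from collections import defaultdict


def agreement_groups(units_by_sorter, pairwise_matches, minimum_agreement_count=2):
    """Union-find over (sorter, unit) nodes; keep groups spanning enough sorters."""
    parent = {(s, u): (s, u) for s, ids in units_by_sorter.items() for u in ids}

    def find(node):
        while parent[node] != node:
            parent[node] = parent[parent[node]]  # path halving
            node = parent[node]
        return node

    def union(a, b):
        parent[find(a)] = find(b)

    for node_a, node_b in pairwise_matches:
        union(node_a, node_b)

    groups = defaultdict(set)
    for node in parent:
        groups[find(node)].add(node)

    kept = []
    for members in groups.values():
        sorters = {sorter for sorter, _ in members}
        if len(sorters) >= minimum_agreement_count:
            kept.append(sorted(members))
    return sorted(kept)


units = {'tdc': ['t0', 't1', 't2'], 'hs': ['h0', 'h1'], 'ks': ['k0']}
matches = [(('tdc', 't0'), ('hs', 'h1')),
           (('hs', 'h1'), ('ks', 'k0')),
           (('tdc', 't1'), ('hs', 'h0'))]
print(agreement_groups(units, matches, minimum_agreement_count=2))
print(agreement_groups(units, matches, minimum_agreement_count=3))
```

Unit t2 was matched by no other sorter, so it never reaches the agreement threshold.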

Total running time of the script: ( 5 minutes 39.337 seconds)
