Curation Tutorial

After spike sorting and computing quality metrics, you can automatically curate the spike sorting output using the quality metrics that you have calculated.

Import the necessary modules and functions from spikeinterface:

import spikeinterface.core as si
import spikeinterface.extractors as se

from spikeinterface.postprocessing import compute_principal_components
from spikeinterface.qualitymetrics import compute_quality_metrics

Let’s download a simulated dataset from the repository https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.

Let’s imagine that the ground-truth sorting is in fact the output of a sorter.

local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording, sorting = se.read_mearec(file_path=local_path)
print(recording)
print(sorting)
MEArecRecordingExtractor: 32 channels - 32.0kHz - 1 segments - 320,000 samples - 10.00s
                          float32 dtype - 39.06 MiB
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
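
Both objects expose basic accessor methods that are handy for a quick sanity check before curation. A minimal sketch (it only uses standard extractor accessors; the prints are illustrative):

fs = recording.get_sampling_frequency()
unit_ids = sorting.get_unit_ids()
print(f"Sampling frequency: {fs} Hz")
print(f"Unit ids: {unit_ids}")
# number of spikes in the first unit
print(len(sorting.get_unit_spike_train(unit_id=unit_ids[0])))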

First, we extract waveforms (to be saved in the folder ‘wfs_mearec’) and compute their PC (principal component) scores:

we = si.extract_waveforms(recording=recording,
                          sorting=sorting,
                          folder='wfs_mearec',
                          ms_before=1,
                          ms_after=2.,
                          max_spikes_per_unit=500,
                          n_jobs=1,
                          chunk_size=30000)
print(we)

pc = compute_principal_components(we, load_if_exists=True, n_components=3, mode='by_channel_local')
WaveformExtractor: 32 channels - 10 units - 1 segments
  before:32 after:64 n_per_units:500 - sparse
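
If you want to inspect what has been extracted, the waveform extractor and the principal component object both expose getter methods. A minimal sketch (the unit id '#0' is taken from the sorting above; since the extractor is sparse, the channel dimension is restricted to each unit's sparse channels):

wfs = we.get_waveforms(unit_id='#0')      # (num_spikes, num_samples, num_channels)
template = we.get_template(unit_id='#0')  # average waveform of the unit
proj = pc.get_projections(unit_id='#0')   # (num_spikes, num_components, num_channels)
print(wfs.shape, template.shape, proj.shape)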

Then we compute some quality metrics:

metrics = compute_quality_metrics(waveform_extractor=we, metric_names=['snr', 'isi_violation', 'nearest_neighbor'])
print(metrics)
    isi_violations_ratio  isi_violations_count  ...  nn_hit_rate  nn_miss_rate
#0                   0.0                   0.0  ...     1.000000      0.001289
#1                   0.0                   0.0  ...     0.990000      0.000746
#2                   0.0                   0.0  ...     0.976744      0.005117
#3                   0.0                   0.0  ...     1.000000      0.000000
#4                   0.0                   0.0  ...     0.989583      0.001012
#5                   0.0                   0.0  ...     0.993243      0.002660
#6                   0.0                   0.0  ...     0.995098      0.000000
#7                   0.0                   0.0  ...     0.986364      0.010753
#8                   0.0                   0.0  ...     0.994845      0.001506
#9                   0.0                   0.0  ...     0.996124      0.003348

[10 rows x 5 columns]
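
We only requested three metrics here. To see which other metrics can be computed, the qualitymetrics module exposes helper functions listing the available metric names (names as in recent SpikeInterface versions):

from spikeinterface.qualitymetrics import get_quality_metric_list, get_quality_pca_metric_list
print(get_quality_metric_list())      # metrics computed from spike trains and waveforms
print(get_quality_pca_metric_list())  # metrics that additionally require PC scores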

We can now threshold each quality metric and select units based on some rules.

The easiest and most intuitive way is to use boolean masking on the dataframe, and then build the list of unit ids that we want to keep:

keep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_ratio'] < 0.2) & (metrics['nn_hit_rate'] > 0.90)
print(keep_mask)

keep_unit_ids = keep_mask[keep_mask].index.values
keep_unit_ids = [unit_id for unit_id in keep_unit_ids]  # convert to a plain Python list
print(keep_unit_ids)
#0     True
#1     True
#2     True
#3     True
#4    False
#5    False
#6     True
#7    False
#8     True
#9     True
dtype: bool
['#0', '#1', '#2', '#3', '#6', '#8', '#9']
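
Equivalently, the same rule can be expressed with the pandas query method, which keeps the whole selection in a single readable string:

query = "snr > 7.5 and isi_violations_ratio < 0.2 and nn_hit_rate > 0.90"
keep_unit_ids = metrics.query(query).index.values.tolist()
print(keep_unit_ids)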

Now let’s create a sorting that contains only the curated units and save it, for example to an NPZ file.

curated_sorting = sorting.select_units(keep_unit_ids)
print(curated_sorting)

se.NpzSortingExtractor.write_sorting(sorting=curated_sorting, save_path='curated_sorting.npz')
UnitsSelectionSorting: 7 units - 1 segments - 32.0kHz
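
If needed, the saved NPZ file can be read back as a sorting object, and the waveform extractor can be restricted to the curated units as well. A sketch (the folder name 'wfs_mearec_curated' is arbitrary):

loaded_sorting = se.NpzSortingExtractor(file_path='curated_sorting.npz')
we_curated = we.select_units(unit_ids=keep_unit_ids, new_folder='wfs_mearec_curated')
print(loaded_sorting)
print(we_curated)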
