Compare multiple sorters and consensus-based method

With three or more spike sorters, the comparison is implemented with a graph-based method. The multiple-sorter comparison can also clean the output by applying a consensus-based method, which keeps only the spike trains and spikes on which multiple sorters agree.

Import

import numpy as np
import matplotlib.pyplot as plt

import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw

First, let’s download a simulated dataset from the repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data

local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording, sorting = se.read_mearec(local_path)
print(recording)
print(sorting)
MEArecRecordingExtractor: 32 channels - 1 segments - 32.0kHz - 10.000s
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5

Then run 3 spike sorters and compare their output.

sorting_MS4 = ss.run_mountainsort4(recording)
sorting_HS = ss.run_herdingspikes(recording)
sorting_TDC = ss.run_tridesclous(recording)
Mountainsort4 use the OLD spikeextractors mapped with NewToOldRecording
# Generating new position and neighbor files from data file
# Not Masking any Channels
# Sampling rate: 32000
# Localization On
# Number of recorded channels: 32
# Analysing frames: 320000; Seconds: 10.0
# Frames before spike in cutout: 10
# Frames after spike in cutout: 58
# tcuts: 42 90
# tInc: 100000
# Detection completed, time taken: 0:00:00.613271
# Time per frame: 0:00:00.001916
# Time per sample: 0:00:00.000060
Loaded 713 spikes.
Fitting dimensionality reduction using all spikes...
...projecting...
...done
Clustering...
Clustering 713 spikes...
number of seeds: 10
seeds/job: 6
using 2 cpus
[Parallel(n_jobs=2)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=2)]: Done   2 out of   2 | elapsed:    2.2s finished
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.95.0/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:242: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  unique = np.ones(len(sorted_centers), dtype=np.bool)
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.95.0/lib/python3.8/site-packages/herdingspikes/clustering/mean_shift_.py:255: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  labels = np.zeros(n_samples, dtype=np.int)
Number of estimated units: 7
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.95.0/lib/python3.8/site-packages/tridesclous/dataio.py:175: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
  v1 = distutils.version.LooseVersion(tridesclous_version).version
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.95.0/lib/python3.8/site-packages/tridesclous/dataio.py:176: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
  v2 = distutils.version.LooseVersion(self.info['tridesclous_version']).version
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.95.0/lib/python3.8/site-packages/tridesclous/dataio.py:175: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
  v1 = distutils.version.LooseVersion(tridesclous_version).version
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/conda/0.95.0/lib/python3.8/site-packages/tridesclous/dataio.py:176: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
  v2 = distutils.version.LooseVersion(self.info['tridesclous_version']).version

Compare multiple spike sorter outputs

mcmp = sc.compare_multiple_sorters(
    sorting_list=[sorting_MS4, sorting_HS, sorting_TDC],
    name_list=['MS4', 'HS', 'TDC'],
    verbose=True,
)
Multicomaprison step 1: pairwise comparison
  Comparing: MS4 and HS
  Comparing: MS4 and TDC
  Comparing: HS and TDC
Multicomparison step 2: make graph
Multicomaprison step 3: clean graph
Removed 0 duplicate nodes
Multicomparison step 4: extract agreement from graph
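
The four steps logged above can be illustrated with a small, library-independent sketch. It is not SpikeInterface's implementation: the unit labels, the list of matches, and the helper connected_components are hypothetical, and the duplicate-cleaning step is left out. Each (sorter, unit) pair is a node, each pairwise match is an edge, and each connected component is a candidate consensus unit.

```python
from collections import defaultdict

# Hypothetical pairwise matches between (sorter, unit) nodes.
matches = [
    (("MS4", 1), ("HS", 0)),
    (("MS4", 1), ("TDC", 3)),
    (("HS", 0), ("TDC", 3)),
    (("MS4", 2), ("TDC", 5)),
]
# Every unit found by any sorter is a node, even if it has no match.
nodes = {("MS4", 1), ("MS4", 2), ("MS4", 4), ("HS", 0), ("TDC", 3), ("TDC", 5)}


def connected_components(nodes, edges):
    """Group nodes that are linked, directly or indirectly, by an edge."""
    adjacency = defaultdict(set)
    for a, b in edges:
        adjacency[a].add(b)
        adjacency[b].add(a)
    seen = set()
    components = []
    for start in sorted(nodes):
        if start in seen:
            continue
        stack = [start]
        component = set()
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            component.add(node)
            stack.extend(adjacency[node] - seen)
        components.append(component)
    return components


components = connected_components(nodes, matches)
for component in components:
    # The number of distinct sorters in a component is its agreement count.
    sorters = sorted({sorter for sorter, _ in component})
    print(len(sorters), sorters)
```

In this toy graph one unit is found by all three sorters, one by two sorters, and one by a single sorter.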

The multiple-sorter comparison internally computes pairwise comparisons, which can be accessed as follows:

# mcmp.comparisons is keyed by pairs of sorter names, so indexing it with 0 raises KeyError
cmp_MS4_HS = mcmp.comparisons[('MS4', 'HS')]
print(cmp_MS4_HS.sorting1, cmp_MS4_HS.sorting2)
print(cmp_MS4_HS.get_matching())
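
To give a feel for what a pairwise comparison produces, here is a minimal sketch with made-up spike counts. It assumes that the agreement score of two units is the number of matched spikes divided by the size of the union of both spike trains, and that a pair counts as matched only above a threshold (0.5 here, an assumed value). The names agreement_score, counts_1, counts_2, and match_counts are illustrative. Unlike a proper assignment method, this simple best-match loop does not enforce one-to-one matching.

```python
def agreement_score(n_a, n_b, n_match):
    """Matched spikes divided by the union of both spike trains."""
    union = n_a + n_b - n_match
    return n_match / union if union > 0 else 0.0


# Hypothetical spike counts per unit, and matched-spike counts per unit pair.
counts_1 = {"a0": 100, "a1": 80}
counts_2 = {"b0": 90, "b1": 60}
match_counts = {("a0", "b0"): 85, ("a0", "b1"): 2, ("a1", "b0"): 1, ("a1", "b1"): 20}

min_score = 0.5  # assumed threshold below which a unit is left unmatched
matching = {}
for u1 in counts_1:
    best_unit, best_score = None, 0.0
    for u2 in counts_2:
        n_match = match_counts.get((u1, u2), 0)
        score = agreement_score(counts_1[u1], counts_2[u2], n_match)
        if score > best_score:
            best_unit, best_score = u2, score
    # -1 marks "no match" in this sketch
    matching[u1] = best_unit if best_score >= min_score else -1

print(matching)  # {'a0': 'b0', 'a1': -1}
```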

The global multi-comparison can be visualized as a graph:

sw.plot_multicomp_graph(mcmp)

We can see that the agreement is best between Tridesclous and Mountainsort4 (5 units matched), while HerdingSpikes has only two matched units with Tridesclous and three with Mountainsort4.

Consensus-based method

We can pull the units on which different sorters agree using the get_agreement_sorting() method. This makes spike sorting more robust by integrating the output of several algorithms. On the other hand, the result can suffer if one of the algorithms performs poorly.

When the units in agreement are extracted, their spike trains are modified so that only the true positive spikes from the comparison with the best match are kept.
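
A minimal sketch of what keeping only the true positive spikes can mean for one pair of matched units. It assumes that a spike is a true positive when the other spike train has an unused spike within delta_frames samples. The function name matched_spikes, the tolerance, and the spike times are illustrative and are not taken from SpikeInterface.

```python
def matched_spikes(train_a, train_b, delta_frames):
    """Return the spikes of train_a that have a partner in train_b.

    Both trains are sorted lists of sample indices. A partner is a spike of
    train_b within delta_frames samples; each one is used at most once.
    """
    kept = []
    j = 0
    for t in train_a:
        # Skip spikes of train_b that are too early to match t.
        while j < len(train_b) and train_b[j] < t - delta_frames:
            j += 1
        if j < len(train_b) and abs(train_b[j] - t) <= delta_frames:
            kept.append(t)
            j += 1  # consume the partner spike
    return kept


train_a = [100, 500, 900, 1300]
train_b = [102, 498, 1310, 2000]
consensus = matched_spikes(train_a, train_b, delta_frames=5)
print(consensus)  # [100, 500]
```

The spikes at 900 and 1300 are dropped because the other train has no spike within 5 samples of them, so the consensus train is shorter than either input.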

agr_3 = mcmp.get_agreement_sorting(minimum_agreement_count=3)
print('Units in agreement for all three sorters: ', agr_3.get_unit_ids())
agr_2 = mcmp.get_agreement_sorting(minimum_agreement_count=2)
print('Units in agreement for at least two sorters: ', agr_2.get_unit_ids())
agr_all = mcmp.get_agreement_sorting()
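
The effect of minimum_agreement_count can be pictured with a toy filter. This is a sketch under assumptions, not the library's code: consensus_units and units_in_agreement are hypothetical names, and the default of 1 belongs to this sketch only.

```python
# Hypothetical consensus units: id -> {sorter name: unit id in that sorter}
consensus_units = {
    0: {"MS4": 1, "HS": 0, "TDC": 3},
    1: {"MS4": 2, "TDC": 5},
    2: {"MS4": 4},
}


def units_in_agreement(units, minimum_agreement_count=1):
    """Keep the units found by at least minimum_agreement_count sorters."""
    return [
        unit_id
        for unit_id, members in units.items()
        if len(members) >= minimum_agreement_count
    ]


print(units_in_agreement(consensus_units, 3))  # [0]
print(units_in_agreement(consensus_units, 2))  # [0, 1]
print(units_in_agreement(consensus_units))     # [0, 1, 2]
```

A higher count gives fewer but more trustworthy units. With a count of 1, every unit found by any sorter is kept.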

The unit IDs from the individual sorters can also be retrieved from the agreement sorting object (agr_3) through its unit_ids property.

print(agr_3.get_property('unit_ids'))
print(agr_3.get_unit_ids())
# take one unit in agreement
unit_id0 = agr_3.get_unit_ids()[0]
sorter_unit_ids = agr_3.get_property('unit_ids')[0]
print(unit_id0, ':', sorter_unit_ids)

Now that we have found our unit, we can plot a raster with the spike trains from the individual sorters and the one from the consensus-based method. When the agreement sorting is extracted, spike trains are cleaned so that only the true positives from the comparison with the largest agreement are kept. Let’s take a look at the raster plots for the different sorters and the agreement sorter:

st0 = sorting_MS4.get_unit_spike_train(sorter_unit_ids['MS4'])
st1 = sorting_HS.get_unit_spike_train(sorter_unit_ids['HS'])
st2 = sorting_TDC.get_unit_spike_train(sorter_unit_ids['TDC'])
st3 = agr_3.get_unit_spike_train(unit_id0)


fig, ax = plt.subplots()
ax.plot(st0, 0 * np.ones(st0.size), '|')
ax.plot(st1, 1 * np.ones(st1.size), '|')
ax.plot(st2, 2 * np.ones(st2.size), '|')
ax.plot(st3, 3 * np.ones(st3.size), '|')

print('mountainsort4 spike train length', st0.size)
print('herdingspikes spike train length', st1.size)
print('Tridesclous spike train length', st2.size)
print('Agreement spike train length', st3.size)

As we can see, the best match is between Mountainsort and Tridesclous, but only the true positive spikes make up the agreement spike train.

plt.show()
