Compare two sorters

This example shows how to compare the results of two sorters.

Import

import numpy as np
import matplotlib.pyplot as plt

import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw

First, let’s download a simulated dataset from the repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data

local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording, sorting = se.read_mearec(local_path)
print(recording)
print(sorting)
MEArecRecordingExtractor: 32 channels - 1 segments - 32.0kHz - 10.000s
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5

Then run two spike sorters and compare their output.

sorting_HS = ss.run_herdingspikes(recording)
sorting_TDC = ss.run_tridesclous(recording)
# Generating new position and neighbor files from data file
# Not Masking any Channels
# Sampling rate: 32000
# Localization On
# Number of recorded channels: 32
# Analysing frames: 320000; Seconds: 10.0
# Frames before spike in cutout: 10
# Frames after spike in cutout: 58
# tcuts: 42 90
# tInc: 100000
# Detection completed, time taken: 0:00:00.608594
# Time per frame: 0:00:00.001902
# Time per sample: 0:00:00.000059
Loaded 713 spikes.
Fitting dimensionality reduction using all spikes...
...projecting...
...done
Clustering...
Clustering 713 spikes...
number of seeds: 10
seeds/job: 6
using 2 cpus
[Parallel(n_jobs=2)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=2)]: Done   2 out of   2 | elapsed:    2.2s finished
Number of estimated units: 7

The compare_two_sorters() function allows us to compare the spike sorting output. It returns a SymmetricSortingComparison object, with methods to inspect the comparison output easily. The comparison matches the units by comparing the agreement between unit spike trains.

Let’s see how to inspect and access this matching.

cmp_HS_TDC = sc.compare_two_sorters(
    sorting1=sorting_HS,
    sorting2=sorting_TDC,
    sorting1_name='HS',
    sorting2_name='TDC',
)
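
To build intuition for what "agreement between unit spike trains" means, here is a minimal NumPy sketch. It is not the library's implementation. It assumes that two spikes match when they fall within a tolerance of `delta_frames` samples, and that the agreement score is the number of matches divided by the size of the union of both trains. The helper names, the toy spike trains, and the tolerance are all hypothetical.

```python
import numpy as np


def count_matches(train_a, train_b, delta_frames):
    """Count spikes of train_a that have a spike of train_b within delta_frames.

    Both trains are sorted 1D arrays of sample indices. This simplified
    version does not prevent one spike of train_b from being used twice.
    """
    train_a = np.asarray(train_a)
    train_b = np.asarray(train_b)
    if train_a.size == 0 or train_b.size == 0:
        return 0
    # position where each spike of train_a would be inserted in train_b
    idx = np.searchsorted(train_b, train_a)
    # nearest neighbours in train_b on the left and on the right
    left = train_b[np.clip(idx - 1, 0, train_b.size - 1)]
    right = train_b[np.clip(idx, 0, train_b.size - 1)]
    nearest = np.minimum(np.abs(train_a - left), np.abs(train_a - right))
    return int(np.sum(nearest <= delta_frames))


def agreement_score(train_a, train_b, delta_frames):
    """Assumed definition: matches / (n_a + n_b - matches)."""
    n_match = count_matches(train_a, train_b, delta_frames)
    denom = len(train_a) + len(train_b) - n_match
    return n_match / denom if denom > 0 else 0.0


# toy spike trains in samples (hypothetical values)
toy_a = np.array([100, 500, 900, 1300])
toy_b = np.array([102, 498, 1500])

n_match = count_matches(toy_a, toy_b, delta_frames=5)
score = agreement_score(toy_a, toy_b, delta_frames=5)
print(n_match, score)
```

With this definition, a score of 1 means the two units contain exactly the same spikes, and a score near 0 means they share almost none.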

We can check the agreement matrix to inspect the matching.

sw.plot_agreement_matrix(cmp_HS_TDC)

Some useful internal dataframes, such as match_event_count and agreement_scores, help to check the matches and their event counts.

print(cmp_HS_TDC.match_event_count)
print(cmp_HS_TDC.agreement_scores)
    0   1   2   3   4   5    6    7    8
0   0   1   0   0   0   0    1  162    3
1   0   0   0   0   0   0  127    1    1
2   0   0   0   0   0   0    2    1  104
3  30   0  51   0   0  48    0    0    0
4   0  53   6   1   0   0    0    2    1
5   0   1   0  50   0   0    0    1    0
6   0   0   0   0  42   0    0    0    0
          0         1         2  ...         6         7         8
0  0.000000  0.004000  0.000000  ...  0.003067  0.704348  0.009836
1  0.000000  0.000000  0.000000  ...  0.984496  0.003125  0.004237
2  0.000000  0.000000  0.000000  ...  0.008621  0.003356  0.936937
3  0.232558  0.000000  0.395349  ...  0.000000  0.000000  0.000000
4  0.000000  0.898305  0.057692  ...  0.000000  0.007968  0.005952
5  0.000000  0.009804  0.000000  ...  0.000000  0.004115  0.000000
6  0.000000  0.000000  0.000000  ...  0.000000  0.000000  0.000000

[7 rows x 9 columns]
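
One way to read the match_event_count table is to look, for each HS unit (row), at the TDC unit (column) with the most matching events. The sketch below copies the printed counts into a NumPy array and does exactly that. It is only a reading aid, not how the library computes the final matching, and the threshold of 10 events used to flag additional candidate partners is an arbitrary choice for illustration.

```python
import numpy as np

# match_event_count copied from the printed output
# rows: HS units 0-6, columns: TDC units 0-8
match_event_count = np.array([
    [0, 1, 0, 0, 0, 0, 1, 162, 3],
    [0, 0, 0, 0, 0, 0, 127, 1, 1],
    [0, 0, 0, 0, 0, 0, 2, 1, 104],
    [30, 0, 51, 0, 0, 48, 0, 0, 0],
    [0, 53, 6, 1, 0, 0, 0, 2, 1],
    [0, 1, 0, 50, 0, 0, 0, 1, 0],
    [0, 0, 0, 0, 42, 0, 0, 0, 0],
])

# TDC unit with the most matching events for each HS unit
best_tdc = match_event_count.argmax(axis=1)
best_count = match_event_count.max(axis=1)

# number of TDC units sharing at least `min_events` events with each HS unit
min_events = 10  # arbitrary illustrative threshold
n_partners = (match_event_count >= min_events).sum(axis=1)

for hs_unit, (tdc_unit, count, n) in enumerate(zip(best_tdc, best_count, n_partners)):
    print(f"HS {hs_unit}: best TDC {tdc_unit} ({count} events), {n} candidate partner(s)")
```

HS unit 3 stands out: its events are spread over TDC units 0, 2 and 5 (30, 51 and 48 events), so no single TDC unit accounts for most of its spikes.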

To check which units were matched, use the get_matching() method. Unmatched units are listed as -1.

sc_to_tdc, tdc_to_sc = cmp_HS_TDC.get_matching()

print('matching HS to TDC')
print(sc_to_tdc)
print('matching TDC to HS')
print(tdc_to_sc)
matching HS to TDC
0    7.0
1    6.0
2    8.0
3   -1.0
4    1.0
5    3.0
6    4.0
dtype: float64
matching TDC to HS
0   -1.0
1    4.0
2   -1.0
3    5.0
4    6.0
5   -1.0
6    1.0
7    0.0
8    2.0
dtype: float64
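
The two printed series can be turned into a list of matched pairs and lists of unmatched units. The sketch below rebuilds them by hand as pandas Series from the output above, assuming the index holds the unit ids, so it runs without the sorters. The variable names are illustrative only.

```python
import pandas as pd

# values copied from the printed output; index = unit id, value = matched unit id
hs_to_tdc = pd.Series([7.0, 6.0, 8.0, -1.0, 1.0, 3.0, 4.0])
tdc_to_hs = pd.Series([-1.0, 4.0, -1.0, 5.0, 6.0, -1.0, 1.0, 0.0, 2.0])

# keep matched units only and convert the float ids back to integers
matched = hs_to_tdc[hs_to_tdc != -1].astype(int)
pairs = list(zip(matched.index.tolist(), matched.tolist()))

# units without a partner in the other sorting
unmatched_hs = hs_to_tdc.index[hs_to_tdc == -1].tolist()
unmatched_tdc = tdc_to_hs.index[tdc_to_hs == -1].tolist()

# the two directions should describe the same pairs
is_symmetric = all(tdc_to_hs[tdc] == hs for hs, tdc in pairs)

print("matched (HS, TDC):", pairs)
print("unmatched HS:", unmatched_hs)
print("unmatched TDC:", unmatched_tdc)
print("symmetric:", is_symmetric)
```

In this output, six pairs are matched in both directions, HS unit 3 has no partner, and TDC units 0, 2 and 5 are unmatched.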

The get_unit_spike_train() method returns the spike train of a given unit. We can use it to check that the spike times of two matched units correspond.

# keep only the HS units that have a match in TDC
matched_ids = sc_to_tdc[sc_to_tdc != -1]

# take the first matched pair of units
unit_id_HS = matched_ids.index[0]
unit_id_TDC = matched_ids[unit_id_HS]

# check that matched spike trains correspond
st1 = sorting_HS.get_unit_spike_train(unit_id_HS)
st2 = sorting_TDC.get_unit_spike_train(unit_id_TDC)
fig, ax = plt.subplots()
ax.plot(st1, np.zeros(st1.size), '|')
ax.plot(st2, np.ones(st2.size), '|')

plt.show()