Compare two sorters

This example shows how to compare the results of two spike sorters.

Import

import numpy as np
import matplotlib.pyplot as plt

import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw

First, let’s create a toy example:

recording, sorting = se.example_datasets.toy_example(num_channels=4, duration=10, seed=0)

Then, let's run two spike sorters and compare their output.

sorting_KL = ss.run_klusta(recording)
sorting_MS4 = ss.run_mountainsort4(recording)

Out:

RUNNING SHELL SCRIPT: /home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/checkouts/0.13.0/examples/modules/comparison/klusta_output/run_klusta.sh
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/checkouts/0.13.0/doc/sources/spikesorters/spikesorters/basesorter.py:158: ResourceWarning: unclosed file <_io.TextIOWrapper name=63 encoding='UTF-8'>
  self._run(recording, self.output_folders[i])
Warning! The recording is already filtered, but Mountainsort4 filter is enabled. You can disable filters by setting 'filter' parameter to False

The compare_two_sorters function compares the outputs of two spike sorters. It returns a SortingComparison object with methods to easily inspect the comparison. Units are matched by measuring the agreement between their spike trains.
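To make "agreement between spike trains" concrete, here is a minimal, simplified sketch. It assumes that two spikes match when their times differ by at most a tolerance `delta` (in samples), that each spike is paired at most once, and that the agreement score of a unit pair is the number of matched spikes divided by the size of the union of both spike sets, n_match / (n1 + n2 - n_match). The helper names and the tolerance are illustrative assumptions, not the spikeinterface API, and the library's own implementation may differ in details.

```python
import numpy as np


def count_matching_spikes(times1, times2, delta):
    """Count one-to-one matches between two sorted spike trains.

    Two spikes match when their times differ by at most ``delta``
    (hypothetical helper, same units as the spike times).
    """
    i = j = n_match = 0
    while i < len(times1) and j < len(times2):
        diff = times2[j] - times1[i]
        if abs(diff) <= delta:
            # spikes coincide within the tolerance: pair them and move on
            n_match += 1
            i += 1
            j += 1
        elif diff < 0:
            # spike of train 2 is too early to match any remaining spike of train 1
            j += 1
        else:
            # spike of train 1 is too early to match any remaining spike of train 2
            i += 1
    return n_match


def agreement_score(times1, times2, delta):
    """Matched spikes divided by the union of both spike sets (assumed definition)."""
    n_match = count_matching_spikes(times1, times2, delta)
    return n_match / (len(times1) + len(times2) - n_match)


# usage with two made-up spike trains (sample indices)
st_a = np.array([100, 250, 400, 900])
st_b = np.array([102, 398, 905, 1500])
print(count_matching_spikes(st_a, st_b, delta=5))  # 3
print(agreement_score(st_a, st_b, delta=5))  # 0.6
```

Three spikes are paired (100/102, 400/398, 900/905), so the score is 3 / (4 + 4 - 3) = 0.6. A score of 1 would mean both units fired exactly the same spikes.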

Let’s see how to inspect and access this matching.

cmp_KL_MS4 = sc.compare_two_sorters(sorting1=sorting_KL, sorting2=sorting_MS4,
                                    sorting1_name='klusta', sorting2_name='ms4')

We can check the agreement matrix to inspect the matching.

sw.plot_agreement_matrix(cmp_KL_MS4)

Out:

<spikewidgets.widgets.agreementmatrixwidget.agreementmatrixwidget.AgreementMatrixWidget object at 0x7f3dfc30a160>

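The comparison object also exposes internal dataframes that are useful to check the matching, such as match_event_count (the number of matched spikes for each pair of units) and agreement_scores (the agreement score for each pair).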

print(cmp_KL_MS4.match_event_count)
print(cmp_KL_MS4.agreement_scores)

Out:

   1   2   3   4  5   6   7  8   9
0  0  22  25   1  0   1  22  2   3
2  0   0   0   0  0   0   0  0  27
3  0   0   0   0  0  11   0  0   2
4  0   0   0   8  0   0   0  0   0
5  0   0   0   0  1   0   0  0   0
6  0   0   0  16  0   0   0  0   0
     1         2         3         4  ...         6         7         8         9
0  0.0  0.289474  0.328947  0.006579  ...  0.008621  0.289474  0.005865  0.028302
2  0.0  0.000000  0.000000  0.000000  ...  0.000000  0.000000  0.000000  0.818182
3  0.0  0.000000  0.000000  0.000000  ...  0.255814  0.000000  0.000000  0.045455
4  0.0  0.000000  0.000000  0.103896  ...  0.000000  0.000000  0.000000  0.000000
5  0.0  0.000000  0.000000  0.000000  ...  0.000000  0.000000  0.000000  0.000000
6  0.0  0.000000  0.000000  0.207792  ...  0.000000  0.000000  0.000000  0.000000

[6 rows x 9 columns]
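The two tables are related. Assuming the score is defined as the number of matched events divided by the number of spikes in the union of both units, the whole agreement matrix can be derived from match_event_count and the spike count of every unit. The sketch below shows this with pandas broadcasting; the unit labels and all counts are made up for illustration and are not taken from the run above.

```python
import numpy as np
import pandas as pd

# Hypothetical matched-event counts: sorter 1 units (rows) x sorter 2 units (columns)
match_event_count = pd.DataFrame([[20, 0], [0, 30]],
                                 index=['a', 'b'], columns=['x', 'y'])
# Hypothetical total number of spikes of every unit
spike_count1 = pd.Series({'a': 25, 'b': 40})
spike_count2 = pd.Series({'x': 20, 'y': 35})

# Align the counts with the table, then build n1 + n2 for every pair
n1 = spike_count1.reindex(match_event_count.index).to_numpy()
n2 = spike_count2.reindex(match_event_count.columns).to_numpy()
union = n1[:, None] + n2[None, :] - match_event_count.to_numpy()

# Element-wise division keeps the row and column labels
agreement_scores = match_event_count / union
print(agreement_scores)
```

Here pair a/x scores 20 / (25 + 20 - 20) = 0.8 and pair b/y scores 30 / (40 + 35 - 30), about 0.667, while pairs without shared spikes score 0.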

To check which units were matched, use the get_mapped_sorting1() and get_mapped_sorting2() methods. Unmatched units are listed as -1.

# units matched to klusta units
mapped_sorting_klusta = cmp_KL_MS4.get_mapped_sorting1()
print('Klusta units:', sorting_KL.get_unit_ids())
print('Klusta mapped units:', mapped_sorting_klusta.get_mapped_unit_ids())

# units matched to ms4 units
mapped_sorting_ms4 = cmp_KL_MS4.get_mapped_sorting2()
print('Mountainsort units:', sorting_MS4.get_unit_ids())
print('Mountainsort mapped units:', mapped_sorting_ms4.get_mapped_unit_ids())

Out:

Klusta units: [0, 2, 3, 4, 5, 6]
Klusta mapped units: [-1, 9, -1, -1, -1, -1]
Mountainsort units: [1, 2, 3, 4, 5, 6, 7, 8, 9]
Mountainsort mapped units: [-1, -1, -1, -1, -1, -1, -1, -1, 2]
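The -1 entries come from the matching step: a pair of units only counts as a match when its agreement score is high enough. The sketch below shows one simple way to turn an agreement matrix into such a mapping, a greedy one-to-one assignment with a score threshold of 0.5. Both the greedy strategy and the threshold value are assumptions for illustration; spikeinterface may use a different assignment method (for example the Hungarian algorithm) and a different default threshold.

```python
import numpy as np


def greedy_match(scores, threshold=0.5):
    """Greedy one-to-one matching on an agreement-score matrix (hypothetical helper).

    Returns, for each row, the index of the matched column (or -1) and,
    for each column, the index of the matched row (or -1).
    """
    scores = np.asarray(scores, dtype=float)
    n_rows, n_cols = scores.shape
    row_to_col = np.full(n_rows, -1)
    col_to_row = np.full(n_cols, -1)
    # visit all unit pairs from the highest to the lowest score
    order = np.argsort(scores, axis=None)[::-1]
    for flat in order:
        r, c = np.unravel_index(flat, scores.shape)
        if scores[r, c] < threshold:
            break  # all remaining pairs are below the threshold
        # accept the pair only if neither unit has been matched yet
        if row_to_col[r] == -1 and col_to_row[c] == -1:
            row_to_col[r] = c
            col_to_row[c] = r
    return row_to_col, col_to_row


# made-up scores: only the second row has a pair above the threshold
scores = np.array([[0.29, 0.33, 0.00],
                   [0.00, 0.00, 0.82]])
row_to_col, col_to_row = greedy_match(scores)
print(row_to_col.tolist())  # [-1, 2]
print(col_to_row.tolist())  # [-1, -1, 1]
```

The result has the same shape as the output above: most units on both sides stay at -1 and only the pair with a high score is mapped in both directions.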

The get_unit_spike_train() method of a mapped sorting returns the spike train of the matched unit from the other sorter. We can use it to compare the spike times.

# find a Klusta unit that has a match
ind = np.where(np.array(mapped_sorting_klusta.get_mapped_unit_ids()) != -1)[0][0]
u1 = sorting_KL.get_unit_ids()[ind]
print(ind, u1)

# check that matched spike trains correspond
st1 = sorting_KL.get_unit_spike_train(u1)
st2 = mapped_sorting_klusta.get_unit_spike_train(u1)
fig, ax = plt.subplots()
ax.plot(st1, np.zeros(st1.size), '|')
ax.plot(st2, np.ones(st2.size), '|')

plt.show()

Out:

1 2
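The raster plot gives a visual check. To put a number on it, one can compute the fraction of spikes in st1 that have a spike in st2 within a small tolerance. The NumPy sketch below does this with np.searchsorted; the helper name and the tolerance (in the same units as the spike trains, samples here) are assumptions. Unlike the agreement score, this check is not one-to-one, so extra spikes in st2 do not lower it.

```python
import numpy as np


def fraction_with_partner(st1, st2, delta):
    """Fraction of st1 spikes with at least one st2 spike within +/- delta.

    Hypothetical helper; st2 must be sorted. Converting to float avoids
    wrap-around when spike trains are stored as unsigned integers.
    """
    st1 = np.asarray(st1, dtype=float)
    st2 = np.asarray(st2, dtype=float)
    if st1.size == 0 or st2.size == 0:
        return 0.0
    # index of the first st2 spike that is >= each st1 spike
    idx = np.searchsorted(st2, st1)
    # distance to the nearest st2 spike on the right and on the left
    right = np.abs(st2[np.minimum(idx, st2.size - 1)] - st1)
    left = np.abs(st1 - st2[np.maximum(idx - 1, 0)])
    nearest = np.minimum(left, right)
    return float(np.mean(nearest <= delta))


print(fraction_with_partner([100, 250, 400, 900], [102, 398, 905, 1500], delta=5))  # 0.75
```

Applied to st1 and st2 from the example above, a value close to 1 indicates that nearly every Klusta spike has a counterpart in the matched Mountainsort unit.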
