Postprocessing Tutorial

Spike sorters generally output a set of units with corresponding spike trains. The toolkit.postprocessing submodule combines a RecordingExtractor with the corresponding SortingExtractor to perform further postprocessing.

import matplotlib.pyplot as plt

import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.toolkit as st

First, let’s download a simulated dataset from the repository ‘https://gin.g-node.org/NeuralEnsemble/ephy_testing_data’:

local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording = se.MEArecRecordingExtractor(local_path)
sorting = se.MEArecSortingExtractor(local_path)
print(recording)
print(sorting)

Out:

MEArecRecordingExtractor: 32 channels - 1 segments - 32.0kHz - 10.000s
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5

Assuming the sorting is the output of a spike sorter, the postprocessing module lets us extract all relevant information from the paired recording and sorting.

Compute spike waveforms

Waveforms are extracted with the WaveformExtractor or directly with the extract_waveforms function (which returns a WaveformExtractor object):

folder = 'waveforms_mearec'
we = si.extract_waveforms(recording, sorting, folder,
                          load_if_exists=True,
                          ms_before=1, ms_after=2., max_spikes_per_unit=500,
                          n_jobs=1, chunk_size=30000)
print(we)

Out:

WaveformExtractor: 32 channels - 10 units - 1 segments
  before:32 after64 n_per_units: 500
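The before and after values printed above are sample counts, while ms_before and ms_after are given in milliseconds. As a quick arithmetic check (a minimal sketch, not the library's internal code; the helper name ms_to_samples is hypothetical), the conversion at the 32 kHz sampling rate of this recording can be written as:

```python
def ms_to_samples(duration_ms, sampling_frequency_hz):
    """Convert a duration in milliseconds to a whole number of samples."""
    return int(duration_ms * sampling_frequency_hz / 1000.0)


sampling_frequency = 32000.0  # 32 kHz, as printed for the recording above

nbefore = ms_to_samples(1, sampling_frequency)    # ms_before=1
nafter = ms_to_samples(2.0, sampling_frequency)   # ms_after=2.

# Each extracted waveform therefore spans nbefore + nafter samples
print(nbefore, nafter, nbefore + nafter)  # 32 64 96
```

This matches the 32 and 64 reported by the WaveformExtractor.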

Let’s plot the waveforms of the first three units on channel index 8:

colors = ['Olive', 'Teal', 'Fuchsia']

fig, ax = plt.subplots()
for i, unit_id in enumerate(sorting.unit_ids[:3]):
    wf = we.get_waveforms(unit_id)
    color = colors[i]
    ax.plot(wf[:, :, 8].T, color=color, lw=0.3)

Compute unit templates

Similarly to waveforms, templates - average waveforms - can be easily retrieved from the WaveformExtractor object:

fig, ax = plt.subplots()
for i, unit_id in enumerate(sorting.unit_ids[:3]):
    template = we.get_template(unit_id)
    color = colors[i]
    ax.plot(template[:, 8].T, color=color, lw=3)
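To make the relation between waveforms and templates concrete, here is a small NumPy sketch on synthetic data. It assumes the waveform array has shape (n_spikes, n_samples, n_channels), as the indexing in the plots above suggests, and uses a plain mean; the random array only stands in for the output of we.get_waveforms(unit_id).

```python
import numpy as np

rng = np.random.default_rng(0)
n_spikes, n_samples, n_channels = 50, 96, 32

# Synthetic stand-in for the waveforms of one unit:
# shape (n_spikes, n_samples, n_channels)
waveforms = rng.normal(size=(n_spikes, n_samples, n_channels))

# Averaging over the spike axis leaves one trace per channel
template = waveforms.mean(axis=0)

print(template.shape)  # (96, 32)
```

The column template[:, 8] is then the average of all single-spike traces on channel index 8.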

Compute unit maximum channel

In a similar way, one can get the recording channel with the ‘extremum’ signal (minimum or maximum). The get_template_extremum_channel function outputs a dictionary with unit_ids as keys and channel_ids as values:

extremum_channels_ids = st.get_template_extremum_channel(we, peak_sign='neg')
print(extremum_channels_ids)

Out:

{'#0': '6', '#1': '10', '#2': '20', '#3': '3', '#4': '14', '#5': '12', '#6': '25', '#7': '22', '#8': '23', '#9': '7'}
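The idea behind the extremum channel can be sketched with NumPy on a toy template. This is an illustration under stated assumptions, not the library's implementation: the helper extremum_channel_index is hypothetical, it returns a channel index rather than a channel id, and it assumes a template of shape (n_samples, n_channels).

```python
import numpy as np


def extremum_channel_index(template, peak_sign="neg"):
    """Return the index of the channel with the largest peak.

    template has shape (n_samples, n_channels).
    """
    if peak_sign == "neg":
        # Deepest trough: flip the sign so that argmax finds the minimum
        per_channel_peak = -template.min(axis=0)
    elif peak_sign == "pos":
        per_channel_peak = template.max(axis=0)
    else:
        raise ValueError("peak_sign must be 'neg' or 'pos'")
    return int(np.argmax(per_channel_peak))


# Toy template: 96 samples on 4 channels
template = np.zeros((96, 4))
template[32, 2] = -80.0  # deep negative trough on channel index 2
template[40, 1] = 50.0   # positive peak on channel index 1

print(extremum_channel_index(template, "neg"))  # 2
print(extremum_channel_index(template, "pos"))  # 1
```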

Compute principal components (aka PCs)

Computing PCA scores for each waveform is common in many applications, including unsupervised validation of spike sorting performance.

There are different ways to compute PC scores from waveforms:
  • “concatenated”: all waveforms are concatenated and a single PCA model is computed (channel information is lost)

  • “by_channel_global”: PCA is computed from a subset of all waveforms and applied independently on each channel

  • “by_channel_local”: PCA is computed and applied to each channel separately

In SpikeInterface, we can compute PC scores with the compute_principal_components function (which returns a WaveformPrincipalComponent object). The PC scores for a unit are retrieved with the get_components method, and their shape is (n_spikes, n_components, n_channels). Here, we compute PC scores and plot the first and second components on channel index 8:

pc = st.compute_principal_components(we, load_if_exists=True,
                                     n_components=3, mode='by_channel_local')
print(pc)

fig, ax = plt.subplots()
for i, unit_id in enumerate(sorting.unit_ids[:3]):
    comp = pc.get_components(unit_id)
    print(comp.shape)
    color = colors[i]
    ax.scatter(comp[:, 0, 8], comp[:, 1, 8], color=color)

Out:

WaveformPrincipalComponent: 32 channels - 1 segments
  mode:by_channel_local n_components:3
(53, 3, 32)
(50, 3, 32)
(43, 3, 32)
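To see where the (n_spikes, n_components, n_channels) shape comes from, here is a minimal NumPy sketch of a per-channel PCA on synthetic waveforms. It is only an illustration of the by-channel idea: the function name pca_scores_by_channel is hypothetical, the model is fitted on whatever waveforms are passed in, and the library may fit and store its models differently.

```python
import numpy as np


def pca_scores_by_channel(waveforms, n_components=3):
    """Fit one PCA per channel and project the waveforms onto it.

    waveforms: (n_spikes, n_samples, n_channels)
    returns:   (n_spikes, n_components, n_channels)
    """
    n_spikes, n_samples, n_channels = waveforms.shape
    scores = np.empty((n_spikes, n_components, n_channels))
    for chan in range(n_channels):
        data = waveforms[:, :, chan]            # (n_spikes, n_samples)
        centered = data - data.mean(axis=0)     # remove the mean trace
        # Rows of vt are the principal axes, ordered by decreasing variance
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        scores[:, :, chan] = centered @ vt[:n_components].T
    return scores


rng = np.random.default_rng(0)
waveforms = rng.normal(size=(53, 96, 32))  # synthetic stand-in for one unit

scores = pca_scores_by_channel(waveforms, n_components=3)
print(scores.shape)  # (53, 3, 32)
```

In the “concatenated” mode, by contrast, each waveform would first be flattened to a single vector of length n_samples * n_channels, so the channel axis disappears from the scores.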

Note that PC scores for all units can be retrieved at once with the get_all_components() function:

all_labels, all_components = pc.get_all_components()
print(all_labels[:40])
print(all_labels.shape)
print(all_components.shape)

cmap = plt.get_cmap('Dark2', len(sorting.unit_ids))

fig, ax = plt.subplots()
for i, unit_id in enumerate(sorting.unit_ids):
    mask = all_labels == unit_id
    comp = all_components[mask, :, :]
    ax.scatter(comp[:, 0, 8], comp[:, 1, 8], color=cmap(i))

Out:

['#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0'
 '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0'
 '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0' '#0']
(745,)
(745, 3, 32)
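The label array has one entry per spike, so a boolean mask on it selects the rows of the component array that belong to one unit. A toy NumPy sketch of this selection (the small arrays below are made up and only mimic the layout of all_labels and all_components):

```python
import numpy as np

# Toy stand-ins: 5 spikes, 3 components, 2 channels
all_labels = np.array(['#0', '#0', '#1', '#2', '#1'])
all_components = np.arange(5 * 3 * 2, dtype=float).reshape(5, 3, 2)

spikes_per_unit = {}
for unit_id in np.unique(all_labels):
    mask = all_labels == unit_id      # True for spikes of this unit
    comp = all_components[mask]       # (n_spikes_of_unit, 3, 2)
    spikes_per_unit[str(unit_id)] = comp.shape[0]

print(spikes_per_unit)  # {'#0': 2, '#1': 2, '#2': 1}
```

The per-unit counts add up to the length of the label array, just as the per-unit spike counts add up to the 745 rows printed above.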

Export sorted data to Phy for manual curation

Finally, it is common to visualize and manually curate the data after spike sorting. To do so, we interface with the Phy GUI (https://phy-contrib.readthedocs.io/en/latest/template-gui/).

Note that export_to_phy is no longer in the toolkit module; it has moved to the exporter module. The export is slow, so it is not run in this tutorial and the code below is left commented out.

First, we need to export the data to the phy format:

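# export_to_phy now lives in the exporter module (assumed import path:
# spikeinterface.exporters; check your installed version)
# from spikeinterface.exporters import export_to_phy
# output_folder = 'mearec_exported_to_phy'
# export_to_phy(we, output_folder,
#               compute_pc_features=False, compute_amplitudes=True,
#               remove_if_exists=True)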

To run phy you can then run (from terminal): phy template-gui mearec_exported_to_phy/params.py

Or from a notebook:  !phy template-gui mearec_exported_to_phy/params.py

After manual curation, you can load the curated data back with the PhySortingExtractor, pointed at the exported folder.

plt.show()
