Core module

Overview

The spikeinterface.core module provides the basic classes and tools of the SpikeInterface ecosystem.

Several base classes are implemented here and inherited throughout the SI code-base. The core classes are: BaseRecording (for raw data), BaseSorting (for spike-sorted data), and WaveformExtractor (for waveform extraction and postprocessing).

There are additional classes to retrieve events (BaseEvent) and to handle unsorted waveform cutouts, or snippets, which are recorded by some acquisition systems (BaseSnippets).

All classes support:
  • metadata handling

  • data on-demand (lazy loading)

  • multiple segments, where each segment is a contiguous piece of data (recording, sorting, events).

Import rules

Importing the SpikeInterface module

import spikeinterface as si

will only import the core module. Other submodules must be imported separately:

import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw

A second option is to import the SpikeInterface package in full mode:

import spikeinterface.full as si

This import statement will import all SpikeInterface modules as a flattened module. Note that importing spikeinterface.full takes a few extra seconds, because some modules use just-in-time numba compilation performed at import time. We recommend this approach to advanced users, since it requires a deeper knowledge of the API.

Recording

The BaseRecording class serves as the basis for all Recording classes. It interfaces with the raw traces and has the following features:

  • retrieve raw and scaled traces from each segment

  • keep info about channel_ids VS channel indices

  • handle probe information

  • store channel properties

  • store object annotations

  • enable grouping, splitting, and slicing

  • handle time information

Here we assume recording is a BaseRecording object with 16 channels:

channel_ids = recording.channel_ids
num_channels = recording.get_num_channels()
sampling_frequency = recording.sampling_frequency

# get number of samples/duration
num_samples_segment = recording.get_num_samples(segment_index=0)
### NOTE ###
# 'segment_index' is required for multi-segment objects
num_total_samples = recording.get_total_samples()
total_duration = recording.get_total_duration()

# retrieve raw traces between frames 100 and 200
traces = recording.get_traces(start_frame=100, end_frame=200, segment_index=0)
# retrieve raw traces only for the first 4 of the channels
traces_slice = recording.get_traces(start_frame=100, end_frame=200, segment_index=0,
                                    channel_ids=channel_ids[:4])
# retrieve traces after scaling to uV
# (requires 'gain_to_uV' and 'offset_to_uV' properties)
traces_uV = recording.get_traces(start_frame=100, end_frame=200, segment_index=0,
                                 return_scaled=True)
# set/get a new channel property (e.g. "quality")
recording.set_property(key="quality", values=["good"] * num_channels)
quality_values = recording.get_property("quality")
# get all available properties
property_keys = recording.get_property_keys()

# set/get an annotation
recording.annotate(date="Recording acquired today")
recording.get_annotation(key="date")

# get new recording with the first 10s of the traces
recording_slice_frames = recording.frame_slice(start_frame=0,
                                               end_frame=int(10*sampling_frequency))
# get new recording with the first 4 channels
recording_slice_chans = recording.channel_slice(channel_ids=channel_ids[:4])
# remove last two channels
recording_rm_chans = recording.remove_channels(channel_ids=channel_ids[-2:])

# set channel grouping (assume we have 4 groups of 4 channels, e.g. tetrodes)
groups = [0] * 4 + [1] * 4 + [2] * 4 + [3] * 4
recording.set_channel_groups(groups)
# split by property
recording_by_group = recording.split_by("group")
# 'recording_by_group' is a dict with group as keys (0,1,2,3) and channel
# sliced recordings as values

# set times (for synchronization) - assume our times start at 300 seconds
import numpy as np
timestamps = np.arange(num_samples_segment) / sampling_frequency + 300
recording.set_times(timestamps, segment_index=0)
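The grouping logic behind split_by("group") can be sketched in plain Python. This is only an illustration, assuming string channel ids and assuming the split simply partitions channels by their property value; group_channels_by_property() is a hypothetical helper, and the real method returns channel-sliced recordings rather than lists of ids.

```python
def group_channels_by_property(channel_ids, property_values):
    """Map each distinct property value to the channel ids that carry it (hypothetical helper)."""
    groups = {}
    for channel_id, value in zip(channel_ids, property_values):
        groups.setdefault(value, []).append(channel_id)
    return groups


# 16 channels in 4 tetrode-like groups, as in the example above
channel_ids = [f"ch{i}" for i in range(16)]
group_values = [0] * 4 + [1] * 4 + [2] * 4 + [3] * 4
channels_by_group = group_channels_by_property(channel_ids, group_values)
print(channels_by_group[1])  # ['ch4', 'ch5', 'ch6', 'ch7']
```

In the library, each of these id lists would presumably be turned into a channel-sliced recording, so every group can then be processed (e.g. spike sorted) on its own.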

Sorting

The BaseSorting class serves as the basis for all Sorting classes. It interfaces with a spike-sorted output and has the following features:

  • retrieve spike trains of different units

  • keep info about unit_ids VS unit indices

  • store unit properties

  • store object annotations

  • enable selection of sub-units

  • handle time information

Here we assume sorting is a BaseSorting object with 10 units:

unit_ids = sorting.unit_ids
num_units = sorting.get_num_units()
sampling_frequency = sorting.sampling_frequency

# retrieve spike trains for a unit (returned as sample indices)
unit0 = unit_ids[0]
spike_train = sorting.get_unit_spike_train(unit_id=unit0, segment_index=0)
# retrieve spikes between frames 100 and 200
spike_train_slice = sorting.get_unit_spike_train(unit_id=unit0,
                                                 start_frame=100, end_frame=200,
                                                 segment_index=0)
### NOTE ###
# 'segment_index' is required for multi-segment objects

# set/get a new unit property (e.g. "quality")
sorting.set_property(key="quality", values=["good"] * num_units)
quality_values = sorting.get_property("quality")
# get all available properties
property_keys = sorting.get_property_keys()

# set/get an annotation
sorting.annotate(date="Spike sorted today")
sorting.get_annotation(key="date")

# get new sorting with the first 10s of spike trains
sorting_slice_frames = sorting.frame_slice(start_frame=0,
                                           end_frame=int(10*sampling_frequency))
# get new sorting with the first 4 units
sorting_select_units = sorting.select_units(unit_ids=unit_ids[:4])

# register 'recording' from previous and get spike trains in seconds
sorting.register_recording(recording)
spike_train_s = sorting.get_unit_spike_train(unit_id=unit0, segment_index=0,
                                             return_times=True)
### NOTE ###
# When running spike sorting in SpikeInterface, the recording is automatically registered. If
# times are not set, the sample indices are divided by the sampling frequency
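The note above can be made concrete with a small sketch of the sample-to-seconds conversion. It assumes that, once times have been set on the registered recording (e.g. with set_times()), each spike time is read from the time vector at the spike's sample index, and that otherwise the sample index is divided by the sampling frequency. frames_to_seconds() is a hypothetical helper, not part of SpikeInterface.

```python
def frames_to_seconds(spike_frames, sampling_frequency, timestamps=None):
    """Convert spike sample indices to seconds (hypothetical helper)."""
    if timestamps is None:
        # no times set on the recording: divide sample indices by the sampling frequency
        return [frame / sampling_frequency for frame in spike_frames]
    # times set with set_times(): read the timestamp of each spike's sample
    return [timestamps[frame] for frame in spike_frames]


fs = 30_000.0
spike_frames = [0, 15_000, 30_000]
times_default = frames_to_seconds(spike_frames, fs)
# a time vector starting at 300 s, as in the set_times() example above
timestamps = [i / fs + 300 for i in range(60_000)]
times_synced = frames_to_seconds(spike_frames, fs, timestamps)
print(times_default)  # [0.0, 0.5, 1.0]
print(times_synced)   # [300.0, 300.5, 301.0]
```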

WaveformExtractor

The WaveformExtractor class is the core object to combine a BaseRecording and a BaseSorting object. Waveforms are very important for additional analysis, and the basis of several postprocessing and quality metrics computations.

The WaveformExtractor allows one to:

  • extract waveforms

  • sub-sample spikes for waveform extraction

  • compute templates (i.e. average extracellular waveforms) with different modes

  • save waveforms in a folder (in numpy / Zarr) for easy retrieval

  • save sparse waveforms or sparsify dense waveforms

  • select units and associated waveforms
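Spike sub-sampling (controlled by max_spikes_per_unit in the extract_waveforms() examples below) can be pictured as drawing at most a fixed number of spikes per unit. The sketch assumes uniform random selection without replacement, which is a plausible strategy but not necessarily the exact one used by the library; subsample_spikes() is a hypothetical helper.

```python
import random


def subsample_spikes(spike_train, max_spikes_per_unit, seed=None):
    """Return at most max_spikes_per_unit spikes, drawn uniformly without replacement."""
    if len(spike_train) <= max_spikes_per_unit:
        return list(spike_train)
    rng = random.Random(seed)
    # keep the selected spikes in temporal order
    return sorted(rng.sample(list(spike_train), max_spikes_per_unit))


spike_train = list(range(0, 3_000_000, 1_000))  # 3000 toy spike frames
selected = subsample_spikes(spike_train, max_spikes_per_unit=500, seed=0)
```

Capping the number of spikes keeps the cost of waveform extraction bounded for units that fire a lot, while a random draw still spans the whole recording.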

By default (mode='folder'), waveforms are saved to a folder structure as .npy files. In addition, waveforms can be extracted in-memory for fast computations (mode='memory'). Note that this mode can quickly fill up your RAM, so use it wisely! Finally, an existing WaveformExtractor can also be saved in Zarr format.

from spikeinterface.core import extract_waveforms, load_waveforms

# extract dense waveforms on 500 spikes per unit
we = extract_waveforms(recording, sorting, folder="waveforms",
                       max_spikes_per_unit=500)
# same, but with parallel processing! (1s chunks processed by 8 jobs)
job_kwargs = dict(n_jobs=8, chunk_duration="1s")
we = extract_waveforms(recording, sorting, folder="waveforms_par",
                       max_spikes_per_unit=500, overwrite=True,
                       **job_kwargs)
# same, but in-memory
we_mem = extract_waveforms(recording, sorting, folder=None,
                           mode="memory", max_spikes_per_unit=500,
                           **job_kwargs)

# load pre-computed waveforms
we_loaded = load_waveforms(folder="waveforms")

# retrieve waveforms and templates for a unit
waveforms0 = we.get_waveforms(unit0)
template0 = we.get_template(unit0)

# compute template standard deviations (average is computed by default)
# (this can also be done within the 'extract_waveforms')
we.precompute_templates(modes=("std",))

# retrieve all template means and standard devs
template_means = we.get_all_templates(mode="average")
template_stds = we.get_all_templates(mode="std")

# save to Zarr
we_zarr = we.save(folder="waveforms.zarr", format="zarr")

# extract sparse waveforms (see Sparsity section)
# this will use 50 spikes per unit to estimate, for each unit, the sparsity within a 40um radius
we_sparse = extract_waveforms(recording, sorting, folder="waveforms_sparse",
                              max_spikes_per_unit=500, sparse=True,
                              method="radius", radius_um=40,
                              num_spikes_for_sparsity=50)
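To clarify what the template modes compute, here is a plain-Python sketch that reduces one unit's waveforms across spikes, separately for every (sample, channel) pair. The waveform values are made up and compute_template() is hypothetical. We assume the "std" mode is a population standard deviation (ddof=0, the numpy.std default). The library itself uses vectorized array operations.

```python
import statistics

# toy waveforms of one unit, indexed as waveforms[spike][sample][channel]
# (3 spikes, 2 samples, 2 channels)
waveforms = [
    [[1.0, 0.0], [-4.0, -1.0]],
    [[3.0, 0.0], [-6.0, -1.0]],
    [[2.0, 0.0], [-5.0, -1.0]],
]

REDUCERS = {
    "average": statistics.fmean,
    "median": statistics.median,
    "std": statistics.pstdev,  # population std (ddof=0)
}


def compute_template(waveforms, mode="average"):
    """Reduce across spikes independently for every (sample, channel) pair."""
    reduce = REDUCERS[mode]
    num_samples = len(waveforms[0])
    num_channels = len(waveforms[0][0])
    return [
        [reduce([wf[s][c] for wf in waveforms]) for c in range(num_channels)]
        for s in range(num_samples)
    ]


template_average = compute_template(waveforms, "average")
template_std = compute_template(waveforms, "std")
print(template_average)  # [[2.0, 0.0], [-5.0, -1.0]]
```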

Event

The BaseEvent class serves as the basis for all Event classes. It allows one to retrieve events and epochs (e.g. TTL pulses). Internally, events are represented as numpy arrays with a structured dtype. The structured dtype must contain the time field, which represents the event times in seconds. Other fields are optional.

Here we assume event is a BaseEvent object with events from two channels:

channel_ids = event.channel_ids
num_channels = event.get_num_channels()
# get structured dtype for the first channel
event_dtype = event.get_dtype(channel_ids[0])
print(event_dtype)
# >>> dtype([('time', '<f8'), ('duration', '<f8'), ('label', '<U100')])

# retrieve events (with structured dtype)
events = event.get_events(channel_id=channel_ids[0], segment_index=0)
# retrieve event times
event_times = event.get_event_times(channel_id=channel_ids[0], segment_index=0)
### NOTE ###
# 'segment_index' is required for multi-segment objects
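Since events are numpy structured arrays, standard numpy field access and boolean masks apply to them. The sketch below builds a small array with the dtype printed above and selects events by time and label. The event values are made up, and we assume that get_event_times() returns the content of the time field.

```python
import numpy as np

# structured dtype matching the one printed above
event_dtype = np.dtype([("time", "<f8"), ("duration", "<f8"), ("label", "<U100")])

# three made-up events: (time in s, duration in s, label)
events = np.array(
    [(0.5, 0.1, "stim_on"), (1.2, 0.0, "ttl"), (3.0, 0.2, "stim_on")],
    dtype=event_dtype,
)

# each field behaves like a named column
event_times = events["time"]

# combine conditions on several fields: "stim_on" events within the first 2 s
mask = (events["time"] >= 0.0) & (events["time"] < 2.0) & (events["label"] == "stim_on")
selected = events[mask]
print(selected["time"])  # [0.5]
```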

Snippets

The BaseSnippets class serves as the basis for all Snippets classes (currently only NumpySnippets and WaveClusSnippetsExtractor are implemented).

It represents unsorted waveform cutouts. Some acquisition systems allow users to set a threshold and only record the times at which a peak was detected, together with the waveform cut out around the peak.

NOTE: while we support this class (mainly for legacy formats), this approach is a bad practice and highly discouraged! Most modern spike sorters, in fact, require the raw traces to perform template matching to recover spikes!

Here we assume snippets is a BaseSnippets object with 16 channels:

channel_ids = snippets.channel_ids
num_channels = snippets.get_num_channels()
# retrieve number of snippets
num_snippets = snippets.get_num_snippets(segment_index=0)
### NOTE ###
# 'segment_index' is required for multi-segment objects
# retrieve total number of snippets across segments
total_snippets = snippets.get_total_snippets()

# retrieve snippet size
nbefore = snippets.nbefore # samples before peak
nsamples_per_snippet = snippets.snippet_len # total
nafter = nsamples_per_snippet - nbefore # samples after peak

# retrieve sample/frame indices
frames = snippets.get_frames(segment_index=0)
# retrieve snippet cutouts
snippet_cutouts = snippets.get_snippets(segment_index=0)
# retrieve snippet cutouts on first 4 channels
snippet_cutouts_slice = snippets.get_snippets(channel_ids=channel_ids[:4],
                                              segment_index=0)
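The relation between nbefore, nafter, and the snippet length can be illustrated by cutting snippets out of a single-channel trace. Each cutout spans nbefore + nafter samples, and the peak sits at index nbefore. cut_snippets() is a hypothetical helper, and dropping peaks too close to the trace edges is our own choice for the sketch.

```python
def cut_snippets(trace, peak_frames, nbefore, nafter):
    """Return one cutout of nbefore + nafter samples per peak (single-channel trace)."""
    snippets = []
    for frame in peak_frames:
        start, end = frame - nbefore, frame + nafter
        if start < 0 or end > len(trace):
            continue  # skip peaks too close to the edges
        snippets.append(trace[start:end])
    return snippets


trace = list(range(100))  # toy trace whose value equals the sample index
snips = cut_snippets(trace, [1, 50, 98], nbefore=2, nafter=3)
print(snips)  # [[48, 49, 50, 51, 52]]
```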

Handling probes

In order to handle probe information, SpikeInterface relies on the probeinterface package. Either a Probe or a ProbeGroup object can be attached to a recording, which loads the probe information (in particular channel locations and, sometimes, groups). ProbeInterface also has a library of available probes, so you can download an existing probe and attach it to a recording with a few lines of code. When a probe is attached to a recording, the location property is automatically set. In addition, the contact_vector property carries detailed information about the probe design.

Here we assume that recording has 64 channels and that it has been recorded by an ASSY-156-P-1 probe from Cambridge Neurotech, wired via an Intan RHD2164 chip to the acquisition device. The probe has 4 shanks, which can be loaded as separate groups (and spike sorted separately):

import probeinterface as pi

# download probe
probe = pi.get_probe(manufacturer='cambridgeneurotech', probe_name='ASSY-156-P-1')
# add wiring
probe.wiring_to_device('ASSY-156>RHD2164')

# set probe
recording_w_probe = recording.set_probe(probe)
# set probe with group info
recording_w_probe = recording.set_probe(probe, group_mode="by_shank")
# set probe in place
recording.set_probe(probe, group_mode="by_shank", in_place=True)

# retrieve probe
probe_from_recording = recording.get_probe()
# retrieve channel locations
locations = recording.get_channel_locations()
# equivalent to recording.get_property("location")

Probe information is automatically propagated in SpikeInterface, for example when slicing a recording by channels or applying preprocessing.

Note that several read_* functions in the extractors module automatically load the probe from the files (including SpikeGLX, Open Ephys (only the NPIX plugin), Maxwell, Biocam, and MEArec).

Sparsity

In several cases, it is not necessary to have waveforms on all channels. This is especially true for high-density probes, such as Neuropixels, because the waveforms of a unit will only appear on a small set of channels. Sparsity is defined as the subset of channels on which waveforms (and related information) are defined. Of course, sparsity is not global, but it is unit-specific.

Sparsity can be computed from a WaveformExtractor object with the compute_sparsity() function:

from spikeinterface.core import compute_sparsity

sparsity = compute_sparsity(we, method="radius", radius_um=40)

The returned sparsity is a ChannelSparsity object, which has convenient methods to access the sparsity information in several ways:

  • sparsity.unit_id_to_channel_ids returns a dictionary with unit ids as keys and the list of associated
    channel_ids as values
  • sparsity.unit_id_to_channel_indices returns a similar dictionary, but instead with channel indices as
    values (which can be used to slice arrays)

There are several methods to compute sparsity, including:

  • method="radius": selects the channels based on the channel locations. For example, using
    radius_um=40 will select, for each unit, the channels which are within 40um of the channel with the
    largest amplitude (extremum channel). This is the recommended method for high-density probes.
  • method="best_channels": selects the best num_channels channels based on their amplitudes. Note that
    in this case the selected channels might not be close to each other.
  • method="threshold": selects channels based on an SNR threshold (threshold argument)
  • method="by_property": selects channels based on a property, such as group. This method is recommended
    when working with tetrodes.
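The "radius" and "best_channels" rules above can be sketched for a single unit in plain Python. We assume the extremum channel is the one with the largest absolute template amplitude and that the radius comparison is inclusive. The helper names, probe geometry, and amplitudes are hypothetical; the resulting dictionaries only mirror the shape of unit_id_to_channel_indices and unit_id_to_channel_ids.

```python
import math


def radius_sparsity(channel_locations, amplitudes, radius_um):
    """Indices of channels within radius_um of the extremum (largest |amplitude|) channel."""
    extremum = max(range(len(amplitudes)), key=lambda i: abs(amplitudes[i]))
    center = channel_locations[extremum]
    return [i for i, loc in enumerate(channel_locations) if math.dist(loc, center) <= radius_um]


def best_channels_sparsity(amplitudes, num_channels):
    """Indices of the num_channels channels with the largest |amplitude|."""
    ranked = sorted(range(len(amplitudes)), key=lambda i: abs(amplitudes[i]), reverse=True)
    return sorted(ranked[:num_channels])


# a single column of 6 contacts spaced 20 um apart, and toy peak amplitudes of one unit
channel_ids = ["c0", "c1", "c2", "c3", "c4", "c5"]
locations = [(0.0, 20.0 * i) for i in range(6)]
amplitudes = [-5.0, -30.0, -80.0, -40.0, -2.0, -60.0]

unit_id_to_channel_indices = {"unit0": radius_sparsity(locations, amplitudes, radius_um=40)}
unit_id_to_channel_ids = {
    unit_id: [channel_ids[i] for i in indices]
    for unit_id, indices in unit_id_to_channel_indices.items()
}
print(unit_id_to_channel_ids)  # {'unit0': ['c0', 'c1', 'c2', 'c3', 'c4']}
print(best_channels_sparsity(amplitudes, num_channels=3))  # [2, 3, 5]
```

The best_channels result skips channel 4, which illustrates how that method can select channels that are not next to each other.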

The computed sparsity can be used in several postprocessing and visualization functions. In addition, a “dense” WaveformExtractor can be saved as “sparse” as follows:

we_sparse = we.save(folder="waveforms_sparse", sparsity=sparsity)

The we_sparse object now has an associated sparsity (we_sparse.sparsity), which is automatically taken into account in downstream analysis (whether a WaveformExtractor is sparse can be checked with the is_sparse() method). Importantly, saving sparse waveforms, especially for high-density probes, dramatically reduces the size of the waveforms folder.

Saving, loading, and compression

The Base SpikeInterface objects (BaseRecording, BaseSorting, and BaseSnippets) hold full information about their history to maintain provenance. Each object is in fact internally represented as a dictionary (si_object.to_dict()) which can be used to re-instantiate the object from scratch (this is true for all objects except in-memory ones, see Object “in-memory”).
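The idea of rebuilding an object from its dictionary can be shown with a toy two-step chain: a raw recording and a filtered view of it. The classes, REGISTRY, and dictionary layout below are hypothetical and much simpler than SpikeInterface's own to_dict() output. The point is that each object stores only its constructor arguments, with parent objects nested as dictionaries, so the whole chain can be re-instantiated.

```python
class ToyRecording:
    """Hypothetical raw recording that remembers its constructor arguments."""

    def __init__(self, file_path, sampling_frequency):
        self._kwargs = {"file_path": file_path, "sampling_frequency": sampling_frequency}

    def to_dict(self):
        return {"class": type(self).__name__, "kwargs": dict(self._kwargs)}


class ToyFilteredRecording:
    """Hypothetical lazy filter on top of a parent recording."""

    def __init__(self, parent, freq_min):
        self._kwargs = {"parent": parent, "freq_min": freq_min}

    def to_dict(self):
        kwargs = dict(self._kwargs)
        kwargs["parent"] = kwargs["parent"].to_dict()  # nest the parent's description
        return {"class": type(self).__name__, "kwargs": kwargs}


REGISTRY = {"ToyRecording": ToyRecording, "ToyFilteredRecording": ToyFilteredRecording}


def from_dict(description):
    """Rebuild an object, recursively rebuilding any nested object descriptions first."""
    kwargs = {
        key: from_dict(value) if isinstance(value, dict) and "class" in value else value
        for key, value in description["kwargs"].items()
    }
    return REGISTRY[description["class"]](**kwargs)


raw = ToyRecording("data.raw", 30_000.0)
filtered = ToyFilteredRecording(raw, freq_min=300.0)
provenance = filtered.to_dict()
rebuilt = from_dict(provenance)
print(rebuilt.to_dict() == provenance)  # True
```

This also suggests why in-memory objects are excluded: their constructor arguments would be the data arrays themselves rather than a path to a file.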

The save() function allows one to easily store SI objects to a folder on disk. BaseRecording objects are stored in binary (.raw) or Zarr (.zarr) format, and BaseSorting and BaseSnippets objects in numpy (.npz) format. Along with the actual data, the save() function also stores the provenance dictionary and all the properties and annotations associated with the object. The save() function also supports parallel processing to speed up the writing process.

From a SpikeInterface folder, the saved object can be reloaded with the load_extractor() function. These saving/loading features make it possible to store SpikeInterface objects efficiently and to distribute processing.

from spikeinterface.core import load_extractor

job_kwargs = dict(n_jobs=8, chunk_duration="1s")
# save recording to folder in binary (default) format
recording_bin = recording.save(folder="recording", **job_kwargs)
# save recording to folder in zarr format (.zarr is appended automatically)
recording_zarr = recording.save(folder="recording", format="zarr", **job_kwargs)
# save snippets to NPZ
snippets_saved = snippets.save(folder="snippets")
# save sorting to NPZ
sorting_saved = sorting.save(folder="sorting")

# reload the binary recording saved above
recording_loaded = load_extractor("recording")

NOTE: by default, the Zarr format applies data compression with the Blosc Zstandard codec and bit shuffling. Any other Zarr-compatible compressor and filters can be applied using the compressor and filters arguments. For example, here we apply LZMA and use a Delta filter:

from numcodecs import LZMA, Delta

compressor = LZMA()
filters = [Delta(dtype="int16")]

# use a new folder name to avoid clashing with "recording.zarr" saved above
recording_custom_comp = recording.save(folder="recording_lzma", format="zarr",
                                       compressor=compressor, filters=filters,
                                       **job_kwargs)
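A Delta filter stores differences between consecutive values instead of the values themselves. For slowly varying traces, those differences are small numbers, which the compressor can usually pack more tightly. The sketch below shows a lossless round trip in plain Python. It is a conceptual illustration that ignores details of the numcodecs implementation, such as fixed-width integer overflow, and the helper names are hypothetical.

```python
from itertools import accumulate


def delta_encode(samples):
    """Keep the first sample, then store the difference between consecutive samples."""
    return samples[:1] + [b - a for a, b in zip(samples, samples[1:])]


def delta_decode(deltas):
    """Invert delta_encode with a running sum."""
    return list(accumulate(deltas))


samples = [1000, 1003, 1005, 1004, 1002, 1001]  # slowly varying int16-like values
deltas = delta_encode(samples)
print(deltas)  # [1000, 3, 2, -1, -2, -1]
```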

Parallel processing and job_kwargs

The core module also contains the basic tools used throughout SpikeInterface for parallel processing of recordings. In general, parallelization is achieved by splitting the recording into many small time chunks and processing them in parallel (for more details, see the ChunkRecordingExecutor class).
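The chunking strategy can be sketched as two steps: compute (start, end) frame boundaries, then hand each chunk to a pool of workers. This is not ChunkRecordingExecutor. It is a simplified illustration that uses a standard-library thread pool on a toy trace, and the helper names are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor


def chunk_boundaries(num_samples, chunk_size):
    """Split the frame range [0, num_samples) into consecutive (start, end) pairs."""
    return [
        (start, min(start + chunk_size, num_samples))
        for start in range(0, num_samples, chunk_size)
    ]


# a 2.5 s toy single-channel "trace" at 30 kHz, processed in 1 s chunks
sampling_frequency = 30_000
trace = list(range(75_000))
chunks = chunk_boundaries(len(trace), chunk_size=sampling_frequency)


def process_chunk(bounds):
    start, end = bounds
    return max(trace[start:end])  # stands in for any per-chunk computation


with ThreadPoolExecutor(max_workers=4) as pool:
    chunk_maxima = list(pool.map(process_chunk, chunks))  # results keep chunk order

print(chunks)        # [(0, 30000), (30000, 60000), (60000, 75000)]
print(chunk_maxima)  # [29999, 59999, 74999]
```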

Many functions support parallel processing (e.g., extract_waveforms(), save(), and many more). In addition to their other arguments, all of these functions accept the so-called job_kwargs: a set of keyword arguments common to all functions that support parallelization:

  • chunk_duration or chunk_size or chunk_memory or total_memory
    • chunk_size: int

      Number of samples per chunk

    • chunk_memory: str

      Memory usage for each job (e.g. ‘100M’, ‘1G’)

    • total_memory: str

      Total memory usage (e.g. ‘500M’, ‘2G’)

    • chunk_duration: str or float or None

      Chunk duration in s if float or with units if str (e.g. ‘1s’, ‘500ms’)

  • n_jobs: int

    Number of jobs to use. With -1, the number of jobs equals the number of cores. A float like 0.5 means half of the available cores.

  • progress_bar: bool

    If True, a progress bar is printed

  • mp_context: str or None

    Context for multiprocessing. It can be None (default), “fork” or “spawn”. Note that “fork” is only available on UNIX systems

The default job_kwargs are n_jobs=1, chunk_duration="1s", progress_bar=True.
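The following sketch shows how chunk_duration and n_jobs values like those above could be turned into concrete numbers. The helpers are hypothetical. The parsing only handles the "s" and "ms" suffixes shown in the list, and rounding a fractional n_jobs down (with a minimum of one job) is our assumption.

```python
import os
import re


def chunk_duration_to_size(chunk_duration, sampling_frequency):
    """Convert a float (seconds) or a string such as '1s' or '500ms' into samples."""
    if isinstance(chunk_duration, str):
        match = re.fullmatch(r"(\d*\.?\d+)(ms|s)", chunk_duration.strip())
        if match is None:
            raise ValueError(f"cannot parse chunk_duration={chunk_duration!r}")
        value, unit = float(match.group(1)), match.group(2)
        seconds = value / 1000.0 if unit == "ms" else value
    else:
        seconds = float(chunk_duration)
    return int(seconds * sampling_frequency)


def resolve_n_jobs(n_jobs, num_cores=None):
    """-1 means all cores; a float is a fraction of the cores (rounded down, at least 1)."""
    num_cores = num_cores or os.cpu_count() or 1
    if n_jobs == -1:
        return num_cores
    if isinstance(n_jobs, float):
        return max(int(n_jobs * num_cores), 1)
    return n_jobs


print(chunk_duration_to_size("500ms", 30_000.0))  # 15000
print(resolve_n_jobs(0.5, num_cores=8))          # 4
```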

Any of these arguments can be overridden by passing it explicitly to a function (e.g., extract_waveforms(..., n_jobs=16)). Alternatively, job_kwargs can be set globally (for each SpikeInterface session) with the set_global_job_kwargs() function:

from spikeinterface.core import set_global_job_kwargs, get_global_job_kwargs

global_job_kwargs = dict(n_jobs=16, chunk_duration="5s", progress_bar=False)
set_global_job_kwargs(**global_job_kwargs)
print(get_global_job_kwargs())
# >>> {'n_jobs': 16, 'chunk_duration': '5s', 'progress_bar': False}
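The precedence described above (arguments passed to a function override the global values, which in turn replace the defaults) amounts to a dictionary merge. The sketch uses hypothetical names and a module-level dictionary; it is not the library's implementation of set_global_job_kwargs().

```python
# hypothetical stand-ins for the session-wide settings (not the library's own code)
DEFAULT_JOB_KWARGS = {"n_jobs": 1, "chunk_duration": "1s", "progress_bar": True}
_session_job_kwargs = dict(DEFAULT_JOB_KWARGS)


def set_session_job_kwargs(**job_kwargs):
    """Update the session-wide values (hypothetical analogue of set_global_job_kwargs())."""
    _session_job_kwargs.update(job_kwargs)


def resolve_job_kwargs(**explicit_kwargs):
    """Explicit arguments win over session-wide values, which replaced the defaults."""
    resolved = dict(_session_job_kwargs)
    resolved.update(explicit_kwargs)
    return resolved


set_session_job_kwargs(n_jobs=16, chunk_duration="5s", progress_bar=False)
print(resolve_job_kwargs(n_jobs=4))
# {'n_jobs': 4, 'chunk_duration': '5s', 'progress_bar': False}
```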

Object “in-memory”

While most of the time SpikeInterface objects will be loaded from a file, it is sometimes convenient to construct in-memory objects (for example, for testing a new method) or to “manually” add some information to the pipeline workflow.

In order to do this, one can use the Numpy* classes: NumpyRecording, NumpySorting, NumpyEvent, and NumpySnippets. These objects behave exactly like normal SpikeInterface objects, but they are not bound to a file. This makes them not dumpable, so parallel processing is not supported. To make them dumpable, one can simply save() them (see Saving, loading, and compression).

In this example, we create a recording and a sorting object from numpy objects:

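import numpy as np
from spikeinterface.core import NumpyRecording, NumpySorting

# in-memory recording
sampling_frequency = 30_000.
duration = 10.
num_samples = int(duration * sampling_frequency)
num_channels = 16
random_traces = np.random.randn(num_samples, num_channels)

recording_memory = NumpyRecording(traces_list=[random_traces],
                                  sampling_frequency=sampling_frequency)
# with more elements in `traces_list` we can make multi-segment objects

# in-memory sorting
num_units = 10
num_spikes_unit = 1000
spike_trains = []
labels = []
for i in range(num_units):
    spike_trains_i = np.random.randint(low=0, high=num_samples, size=num_spikes_unit)
    labels_i = [i] * num_spikes_unit
    spike_trains.append(spike_trains_i)
    labels.append(labels_i)

# merge all units into flat arrays and sort the spikes by time
spike_trains = np.concatenate(spike_trains)
labels = np.concatenate(labels)
order = np.argsort(spike_trains)
spike_trains, labels = spike_trains[order], labels[order]

# arguments: spike times (in samples), unit labels, sampling frequency
sorting_memory = NumpySorting.from_times_labels(spike_trains, labels, sampling_frequency)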

Manipulating objects: slicing, aggregating

BaseRecording (and BaseSnippets) and BaseSorting objects can be sliced along the time or channel/unit axis.

These operations are completely lazy, as no data is duplicated. After slicing or aggregating, the new objects are views of the original ones.

from spikeinterface.extractors import read_spikeglx, read_kilosort

# here we load a very long recording and sorting
recording = read_spikeglx('np_folder')
sorting = read_kilosort('ks_folder')

# keep one channel every ten channels
keep_ids = recording.channel_ids[::10]
sub_recording_chans = recording.channel_slice(channel_ids=keep_ids)

# keep between 5min and 12min
fs = recording.sampling_frequency
sub_recording = recording.frame_slice(start_frame=int(fs * 60 * 5), end_frame=int(fs * 60 * 12))
sub_sorting = sorting.frame_slice(start_frame=int(fs * 60 * 5), end_frame=int(fs * 60 * 12))

# keep only the first 4 units
sub_sorting_units = sorting.select_units(unit_ids=sorting.unit_ids[:4])
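To picture what frame_slice() does to a spike train, the sketch keeps the spikes falling in [start_frame, end_frame) and shifts them so that the sliced object starts at frame 0. The re-referencing to start_frame is our assumption about the behaviour, slice_spike_train() is hypothetical, and unlike the real lazy view the sketch computes its result eagerly.

```python
def slice_spike_train(spike_frames, start_frame, end_frame):
    """Keep spikes in [start_frame, end_frame) and re-reference them to start_frame."""
    return [frame - start_frame for frame in spike_frames if start_frame <= frame < end_frame]


fs = 30_000.0
start, end = int(fs * 60 * 5), int(fs * 60 * 12)  # 5 min to 12 min, as above
spike_frames = [100, start, start + 30_000, end - 1, end]
sliced = slice_spike_train(spike_frames, start, end)
print(sliced)  # [0, 30000, 12599999]
```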

We can also aggregate (or stack) multiple recordings along the channel axis using the aggregate_channels() function. Note that for this operation the recordings need to have the same sampling frequency, number of segments, and number of samples:

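from spikeinterface.core import aggregate_channels, read_binary

# read_binary() also needs the file layout (e.g. sampling frequency and dtype), omitted here
recA_4_chans = read_binary('fileA.raw')
recB_4_chans = read_binary('fileB.raw')
rec_8_chans = aggregate_channels([recA_4_chans, recB_4_chans])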
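The compatibility requirements above can be expressed as a small check on recording metadata before stacking channel ids. RecordingInfo and aggregate_channel_info() are hypothetical stand-ins that carry only metadata; the real aggregate_channels() returns a lazy recording that reads traces from each input.

```python
from dataclasses import dataclass


@dataclass
class RecordingInfo:
    """Hypothetical stand-in holding only the metadata relevant to aggregation."""
    channel_ids: list
    sampling_frequency: float
    num_samples_per_segment: list


def aggregate_channel_info(recordings):
    first = recordings[0]
    for rec in recordings[1:]:
        if rec.sampling_frequency != first.sampling_frequency:
            raise ValueError("recordings must share the same sampling frequency")
        # comparing the per-segment lengths also checks the number of segments
        if rec.num_samples_per_segment != first.num_samples_per_segment:
            raise ValueError("recordings must have the same segments and number of samples")
    channel_ids = [cid for rec in recordings for cid in rec.channel_ids]
    return RecordingInfo(channel_ids, first.sampling_frequency, list(first.num_samples_per_segment))


rec_a = RecordingInfo(["A0", "A1", "A2", "A3"], 30_000.0, [300_000])
rec_b = RecordingInfo(["B0", "B1", "B2", "B3"], 30_000.0, [300_000])
rec_8 = aggregate_channel_info([rec_a, rec_b])
print(len(rec_8.channel_ids))  # 8
```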

We can also aggregate (or stack) multiple sortings on the unit axis using the aggregate_units() function:

sortingA = read_npz('sortingA.npz')
sortingB = read_npz('sortingB.npz')
sorting_20_units = aggregate_units([sortingA, sortingB])

Working with multiple segments

Multi-segment objects can result from running different recording phases (e.g., baseline, stimulation, post-stimulation) without moving the underlying probe (e.g., just clicking play/pause on the acquisition software). Therefore, multiple segments are assumed to record from the same set of neurons.

We have several functions to manipulate segments of SpikeInterface objects. All these manipulations are lazy.

# recording2: recording with 2 segments
# recording3: recording with 3 segments

# `append_recordings` will append all segments of multiple recordings
recording5 = append_recordings([recording2, recording3])
# `recording5` will have 5 segments

# `concatenate_recordings` will make a mono-segment recording by virtual concatenation
recording_mono = concatenate_recordings([recording2, recording5])

# `split_recording` will return a list of mono-segment recordings out of a multi-segment one
recording_mono_list = split_recording(recording5)
# `recording_mono_list` will have 5 elements with 1 segment

# `select_segment_recording` will return a user-defined subset of segments
recording_select1 = select_segment_recording(recording5, segment_indices=3)
# `recording_select1` will have 1 segment (the 4th one)
recording_select2 = select_segment_recording(recording5, segment_indices=[0, 4])
# `recording_select2` will have 2 segments (the 1st and last one)

The same functions, except for the concatenate_* one, are also available for BaseSorting objects (append_sortings(), split_sorting(), select_segment_sorting()).

Note that append_recordings() and concatenate_recordings() have the same goal, aggregating recording pieces along the time axis, but with two different strategies: one keeps the multi-segment concept, the other breaks it! See the example Append and/or concatenate segments for more details.
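The difference between the two strategies comes down to indexing. Appended recordings keep one sample range per segment, while a concatenated recording exposes a single range and must translate each frame back to a segment and a local frame. The sketch below shows that translation with bisect; locate_frame() is hypothetical and not SpikeInterface code.

```python
import bisect
from itertools import accumulate


def locate_frame(num_samples_per_segment, global_frame):
    """Map a frame of the virtually concatenated recording to (segment_index, local_frame)."""
    segment_ends = list(accumulate(num_samples_per_segment))  # exclusive end of each segment
    if not 0 <= global_frame < segment_ends[-1]:
        raise IndexError(f"frame {global_frame} is out of range")
    segment_index = bisect.bisect_right(segment_ends, global_frame)
    segment_start = segment_ends[segment_index - 1] if segment_index > 0 else 0
    return segment_index, global_frame - segment_start


# appended: 3 segments of 100, 50, and 200 samples; concatenated: 1 segment of 350 samples
num_samples_per_segment = [100, 50, 200]
print(locate_frame(num_samples_per_segment, 99))   # (0, 99)
print(locate_frame(num_samples_per_segment, 100))  # (1, 0)
print(locate_frame(num_samples_per_segment, 349))  # (2, 199)
```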

Recording tools

The spikeinterface.core.recording_tools submodule offers some utility functions on top of the recording object.

Template tools

The spikeinterface.core.template_tools submodule includes functionalities on top of the WaveformExtractor object to retrieve important information about the templates.

Generate toy objects

The core module also offers some functions to generate toy/fake data. They are useful to make examples, tests, and small demos:

from spikeinterface.core import generate_recording, generate_sorting, generate_snippets

# recording with 2 segments and 4 channels
recording = generate_recording(num_channels=4, sampling_frequency=30000.,
                               durations=[10.325, 3.5], set_probe=True)

# sorting with 2 segments and 5 units
sorting = generate_sorting(num_units=5, sampling_frequency=30000., durations=[10.325, 3.5],
                           firing_rate=15, refractory_period=1.5)

# snippets of 60 samples on 2 channels from 5 units
snippets = generate_snippets(nbefore=20, nafter=40, num_channels=2,
                             sampling_frequency=30000., durations=[10.325, 3.5],
                             set_probe=True,  num_units=5)

There are also some more advanced functions to generate sorting objects with various “mistakes” (mainly for testing purposes):

  • synthesize_random_firings()

  • clean_refractory_period()

  • inject_some_duplicate_units()

  • inject_some_split_units()

  • synthetize_spike_train_bad_isi()

Downloading test datasets

The NEO package maintains a collection of files in many electrophysiology formats: https://gin.g-node.org/NeuralEnsemble/ephy_testing_data

The download_dataset() function can download datasets from this repository and cache them locally. The function depends on the datalad Python package, which internally depends on git and git-annex.

The download_dataset() function is very useful for running local tests on small files in various formats:

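from spikeinterface.core import download_dataset
from spikeinterface.extractors import read_spike2, read_mearec, read_spikeglx

# Spike2 format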
local_file_path = download_dataset(remote_path='spike2/130322-1LY.smr')
rec = read_spike2(local_file_path)

# MEArec format
local_file_path = download_dataset(remote_path='mearec/mearec_test_10s.h5')
rec, sorting = read_mearec(local_file_path)

# SpikeGLX format
local_folder_path = download_dataset(remote_path='spikeglx/multi_trigger_multi_gate')
rec = read_spikeglx(local_folder_path)