Build a RecordingExtractor

Building a new RecordingExtractor for a specific file format is as simple as creating a new subclass based on the predefined base classes provided in the spikeextractors package.

To enable standardization among subclasses, RecordingExtractor is an abstract base class that requires a new subclass to override all methods decorated with @abstractmethod. The RecordingExtractor class has four abstract methods: get_channel_ids(), get_num_frames(), get_sampling_frequency(), and get_traces(). So all you need to do is create a class that inherits from RecordingExtractor and implements these four methods.
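The abstract-method mechanism can be tried without spikeextractors installed. The sketch below uses a hypothetical, simplified stand-in base class, `ToyRecordingExtractor` (not the real `RecordingExtractor`), that declares the same four abstract methods. Python's `abc` module refuses to instantiate any subclass that leaves one of them unimplemented.

```python
from abc import ABC, abstractmethod

import numpy as np


class ToyRecordingExtractor(ABC):
    """Hypothetical stand-in for RecordingExtractor with the same four abstract methods."""

    @abstractmethod
    def get_channel_ids(self):
        ...

    @abstractmethod
    def get_num_frames(self):
        ...

    @abstractmethod
    def get_sampling_frequency(self):
        ...

    @abstractmethod
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        ...


class IncompleteExtractor(ToyRecordingExtractor):
    # get_traces() is missing, so this class is still abstract
    def get_channel_ids(self):
        return [0, 1]

    def get_num_frames(self):
        return 100

    def get_sampling_frequency(self):
        return 30000.0


class CompleteExtractor(IncompleteExtractor):
    # Implementing the last abstract method makes the class instantiable
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        return np.zeros((len(self.get_channel_ids()), self.get_num_frames()))


try:
    IncompleteExtractor()
except TypeError as err:
    print("cannot instantiate:", err)

rec = CompleteExtractor()
print(rec.get_traces().shape)  # (2, 100)
```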

Along with these four methods, you can optionally override the write_recording() function, which enables any RecordingExtractor to be written into your format. Also, if you have an implementation of get_snippets() that is more efficient than the original implementation, you can override that as well.
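As an illustration of what a faster snippet routine can look like, the sketch below cuts all windows from a traces array with one vectorized numpy index instead of one read per spike. It is a hypothetical helper (`extract_snippets`), not the library's get_snippets(): it assumes each window is `[ref - frames_before, ref + frames_after)` and that samples falling outside the recording are zero-padded. Check the real method's signature and edge handling before overriding it.

```python
import numpy as np


def extract_snippets(traces, reference_frames, frames_before, frames_after):
    """Cut windows [ref - frames_before, ref + frames_after) from a
    (num_channels x num_frames) array; samples outside the recording are zero.

    Returns an array of shape (num_snippets, num_channels, snippet_len).
    """
    num_frames = traces.shape[1]
    # Frame index for every (snippet, sample) pair: shape (num_snippets, snippet_len)
    offsets = np.arange(-frames_before, frames_after)
    idx = np.asarray(reference_frames)[:, None] + offsets[None, :]
    valid = (idx >= 0) & (idx < num_frames)
    # Clip so fancy indexing stays in bounds, then zero out the invalid samples
    clipped = np.clip(idx, 0, num_frames - 1)
    snippets = traces[:, clipped]                 # (num_channels, num_snippets, snippet_len)
    snippets = np.transpose(snippets, (1, 0, 2))  # (num_snippets, num_channels, snippet_len)
    return np.where(valid[:, None, :], snippets, 0)


traces = np.arange(20).reshape(2, 10)  # 2 channels, 10 frames
snips = extract_snippets(traces, reference_frames=[1, 5], frames_before=2, frames_after=2)
print(snips.shape)             # (2, 2, 4)
print(snips[0, 1].tolist())    # [0, 10, 11, 12]  (first sample is zero-padded)
```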

Any other methods, such as set_channel_locations() or get_epoch(), should not be overridden, as they are generic functions that any RecordingExtractor has access to upon initialization.

Finally, if your file format contains information about the channels (e.g. location, group, etc.), we suggest adding it as a channel property upon initialization (this is optional).
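Channel metadata usually has to be reshaped before it can be stored as properties. The sketch below assumes a hypothetical parsed header (a dict with one entry per channel; real formats will differ) and turns it into a `(num_channels x 2)` locations array and a per-channel group array, the shapes the template's `__init__` comments expect.

```python
import numpy as np

# Hypothetical header, as a file-format reader might return it
header = {
    "channels": [
        {"id": 0, "x": 0.0, "y": 0.0, "shank": 0},
        {"id": 1, "x": 0.0, "y": 25.0, "shank": 0},
        {"id": 2, "x": 200.0, "y": 0.0, "shank": 1},
    ]
}


def channel_info_from_header(header):
    """Return (channel_ids, locations, groups) from a parsed header.

    locations has shape (num_channels, 2); groups has length num_channels.
    """
    channels = header["channels"]
    channel_ids = [ch["id"] for ch in channels]
    locations = np.array([[ch["x"], ch["y"]] for ch in channels], dtype=float)
    groups = np.array([ch["shank"] for ch in channels], dtype=int)
    assert locations.shape == (len(channel_ids), 2)
    return channel_ids, locations, groups


channel_ids, locations, groups = channel_info_from_header(header)
print(locations.shape)   # (3, 2)
print(groups.tolist())   # [0, 0, 1]
```

Inside the extractor's `__init__`, these arrays would then be passed to `self.set_channel_locations(locations)` and `self.set_channel_groups(groups)`.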

An example of a RecordingExtractor that adds channel locations is shown here.

The contributed extractors are in the spikeextractors/extractors folder. You can fork the repo and create a new folder myformatextractors there. In the folder, create a new file named

from spikeextractors import RecordingExtractor
from spikeextractors.extraction_tools import check_get_traces_args, check_get_ttl_args

try:
    import mypackage
    HAVE_MYPACKAGE = True
except ImportError:
    HAVE_MYPACKAGE = False

class MyFormatRecordingExtractor(RecordingExtractor):
    """
    Description of your recording extractor

    Parameters
    ----------
    file_path: str or Path
        Path to myformat file
    extra_parameter: (type)
        What extra_parameter does
    """
    extractor_name = 'MyFormatRecording'
    has_default_locations = False  # set to True if extractor has default locations
    has_unscaled = False  # set to True if traces can be returned in raw format (e.g. uint16/int16)
    installed = HAVE_MYPACKAGE  # check at class level if installed or not
    is_writable = True  # set to True if extractor implements `write_recording()` function
    mode = 'file'  # 'file' if input is 'file_path', 'folder' if input is 'folder_path', 'file_or_folder' if input is 'file_or_folder_path'
    installation_mesg = "To use the MyFormatRecordingExtractor install mypackage: \n\n pip install mypackage\n\n"

    def __init__(self, file_path, extra_parameter):
        # check if installed
        assert self.installed, self.installation_mesg

        # instantiate base RecordingExtractor
        RecordingExtractor.__init__(self)

        ## All file-specific initialization code can go here.

        # Important pieces of information include (if available): channel locations, groups, gains, and offsets
        # To set these, one can use:
        # If the recording has default locations, they can be set as follows:
        self.set_channel_locations(locations)  # locations is a np.array (num_channels x 2)
        # If the recording has intrinsic channel groups, they can be set as follows:
        self.set_channel_groups(groups)  # groups is a list or a np.array with length num_channels
        # If the recording has unscaled traces, gains and offsets can be set as follows:
        self.set_channel_gains(gains)  # gains is a list or a np.array with length num_channels
        self.set_channel_offsets(offsets)  # offsets is a list or a np.array with length num_channels
        # If the recording has times in seconds that are not regularly sampled (e.g. missing frames)
        # times in seconds can be set as follows:
        self.set_times(times)  # times is a np.array with length num_frames

        ### IMPORTANT ###
        # gains and offsets are used to automatically convert raw data to uV (float) in the following way:
        # traces_uV = traces_raw * gains - offsets

    def get_channel_ids(self):

        # Fill code to get a list of channel_ids. If channel ids are not specified, you can use:
        # channel_ids = list(range(num_channels))

        return channel_ids

    def get_num_frames(self):

        # Fill code to get the number of frames (samples) in the recordings.

        return num_frames

    def get_sampling_frequency(self):

        # Fill code to get the sampling frequency of the recording.

        return sampling_frequency

    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
        '''This function extracts and returns a trace from the recorded data from the
        given channels ids and the given start and end frame. It will return
        traces from within three ranges:

            [start_frame, start_frame+1, ..., end_frame-1]
            [start_frame, start_frame+1, ..., final_recording_frame - 1]
            [0, 1, ..., end_frame-1]
            [0, 1, ..., final_recording_frame - 1]

        if both start_frame and end_frame are given, if only start_frame is
        given, if only end_frame is given, or if neither start_frame nor end_frame
        is given, respectively. Traces are returned in a 2D array that
        contains all of the traces from each channel with dimensions
        (num_channels x num_frames). In this implementation, start_frame is inclusive
        and end_frame is exclusive, conforming to numpy standards.

        Parameters
        ----------
        start_frame: int
            The starting frame of the trace to be returned (inclusive).
        end_frame: int
            The ending frame of the trace to be returned (exclusive).
        channel_ids: array_like
            A list or 1D array of channel ids (ints) from which each trace will be extracted.
        return_scaled: bool
            If True, traces are returned after scaling (using gain/offset). If False, the raw traces are returned.

        Returns
        -------
        traces: numpy.ndarray
            A 2D array that contains all of the traces from each channel.
            Dimensions are: (num_channels x num_frames)
        '''

        # Fill code to get the traces of the specified channel_ids, from start_frame to end_frame
        ### IMPORTANT ###
        # If raw traces are available (e.g. int16/uint16), this function should return the raw traces only!
        # If gains and offsets are set in the init, the conversion to float is done automatically (depending on the
        # return_scaled argument).

        return traces

    # optional
    @check_get_ttl_args
    def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0):
        '''Returns an array with frames of TTL signals. To be implemented in sub-classes.

        Parameters
        ----------
        start_frame: int
            The starting frame of the ttl to be returned (inclusive)
        end_frame: int
            The ending frame of the ttl to be returned (exclusive)
        channel_id: int
            The TTL channel id

        Returns
        -------
        ttl_frames: array-like
            Frames of TTL signal for the specified channel
        ttl_states: array-like
            State of the transition: 1 - rising, -1 - falling
        '''

        # Fill code to return ttl frames and states

        return ttl_frames, ttl_states

    # Optional functions and pre-implemented functions that a new RecordingExtractor doesn't need to implement

    @staticmethod
    def write_recording(recording, save_path, other_params):
        '''
        This is an example of a function that is not abstract, so overriding it is optional.
        It allows other RecordingExtractors to use your new RecordingExtractor to convert their
        recorded data into your recording file format.
        '''
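The frame-range rules in the get_traces() docstring and the rising/falling convention of get_ttl_events() can be tried on plain numpy arrays. The class below is a hypothetical, self-contained toy (`InMemoryToyExtractor`, not part of spikeextractors and not a subclass of RecordingExtractor). It holds raw traces and a single digital TTL line in memory, leaves out the return_scaled/gain/offset handling and the channel_id argument, and assumes a TTL event is reported at the first frame that carries the new level.

```python
import numpy as np


class InMemoryToyExtractor:
    """Hypothetical stand-in: raw traces held in memory plus one digital TTL line."""

    def __init__(self, raw_traces, sampling_frequency, ttl_signal):
        self._raw = np.asarray(raw_traces)               # (num_channels x num_frames)
        self._fs = float(sampling_frequency)
        self._ttl = np.asarray(ttl_signal).astype(bool)  # one value per frame

    def get_channel_ids(self):
        return list(range(self._raw.shape[0]))

    def get_num_frames(self):
        return self._raw.shape[1]

    def get_sampling_frequency(self):
        return self._fs

    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        # None means "from the first frame" / "to the last frame"
        if channel_ids is None:
            channel_ids = self.get_channel_ids()
        if start_frame is None:
            start_frame = 0
        if end_frame is None:
            end_frame = self.get_num_frames()
        # start_frame inclusive, end_frame exclusive, as in numpy slicing
        return self._raw[channel_ids, start_frame:end_frame]

    def get_ttl_events(self, start_frame=None, end_frame=None):
        start = 0 if start_frame is None else start_frame
        end = self.get_num_frames() if end_frame is None else end_frame
        # diff[i] compares frame i+1 with frame i: +1 rising, -1 falling
        diff = np.diff(self._ttl.astype(np.int8))
        frames = np.nonzero(diff)[0] + 1
        states = diff[frames - 1]
        keep = (frames >= start) & (frames < end)
        return frames[keep], states[keep]


raw = np.arange(12, dtype=np.int16).reshape(2, 6)
rec = InMemoryToyExtractor(raw, 30000, ttl_signal=[0, 1, 1, 0, 0, 1])
print(rec.get_traces(channel_ids=[1], start_frame=2).tolist())  # [[8, 9, 10, 11]]
frames, states = rec.get_ttl_events()
print(frames.tolist(), states.tolist())  # [1, 3, 5] [1, -1, 1]
```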

When you are done, add your RecordingExtractor to the file. You can optionally write a test in the tests/ folder (this is easier if a write_recording() function is implemented).
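A round-trip test checks that data written in your format reads back unchanged. The sketch below is written as a pytest-style test function with plain `assert`s and uses a hypothetical raw binary layout (int16 samples interleaved frame by frame) through hypothetical `write_myformat`/`read_myformat` helpers. In a real test, these would be replaced by your extractor's `write_recording()` and its constructor plus get_traces().

```python
import os
import tempfile

import numpy as np


def write_myformat(traces, path):
    """Hypothetical writer: store int16 samples interleaved frame by frame."""
    # (num_channels x num_frames) -> (num_frames x num_channels), C order on disk
    np.ascontiguousarray(traces.T, dtype=np.int16).tofile(path)


def read_myformat(path, num_channels):
    """Hypothetical reader matching write_myformat; returns (num_channels x num_frames)."""
    data = np.fromfile(path, dtype=np.int16)
    return data.reshape(-1, num_channels).T


def test_roundtrip():
    rng = np.random.default_rng(0)
    traces = rng.integers(-1000, 1000, size=(4, 250), dtype=np.int16)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "recording.dat")
        write_myformat(traces, path)
        loaded = read_myformat(path, num_channels=traces.shape[0])
    assert loaded.shape == traces.shape
    assert np.array_equal(loaded, traces)


test_roundtrip()
```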

Finally, make a pull request to the spikeextractors repo so we can review the code and merge it into spikeextractors!