{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Preprocessing Tutorial\n\nBefore spike sorting, you may need to preprocess your signals to improve spike sorting performance.\nYou can do that in SpikeInterface using the :py:mod:`spikeinterface.preprocessing` submodule.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport matplotlib.pyplot as plt\nimport scipy.signal\n\nimport spikeinterface.extractors as se\nfrom spikeinterface.preprocessing import (bandpass_filter, notch_filter, common_reference,\n                                          remove_artifacts, preprocesser_dict)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, let's create a toy example:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "recording, sorting = se.toy_example(num_channels=4, duration=10, seed=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Apply filters\n\nNow apply a bandpass filter and a notch filter (separately) to the\nrecording extractor. The filtered recordings are also :py:class:`~spikeinterface.core.BaseRecording` objects.\nNote that these operations are **lazy**: the computation is done on the fly\nwhen you call `rec.get_traces()`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "recording_bp = bandpass_filter(recording, freq_min=300, freq_max=6000)\nprint(recording_bp)\nrecording_notch = notch_filter(recording, freq=2000, q=30)\nprint(recording_notch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now let's plot the power spectrum of non-filtered, bandpass filtered,\nand notch filtered recordings.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fs = recording.get_sampling_frequency()\n\nf_raw, p_raw = scipy.signal.welch(recording.get_traces(segment_index=0)[:, 0], fs=fs)\nf_bp, p_bp = scipy.signal.welch(recording_bp.get_traces(segment_index=0)[:, 0], fs=fs)\nf_notch, p_notch = scipy.signal.welch(recording_notch.get_traces(segment_index=0)[:, 0], fs=fs)\n\nfig, ax = plt.subplots()\nax.semilogy(f_raw, p_raw, f_bp, p_bp, f_notch, p_notch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Change reference\n\nIn many cases, before spike sorting, it is wise to re-reference the\nsignals to reduce the common-mode noise from the recordings.\n\nTo re-reference in :py:mod:`spikeinterface.preprocessing` you can use the\n:py:func:`~spikeinterface.preprocessing.common_reference`\nfunction. Both common average reference (CAR) and common median\nreference (CMR) can be applied. Moreover, the average/median can be\ncomputed on different groups. Single channels can also be used as\nreference.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "recording_car = common_reference(recording, reference='global', operator='average')\nrecording_cmr = common_reference(recording, reference='global', operator='median')\nrecording_single = common_reference(recording, reference='single', ref_channel_ids=[1])\nrecording_single_groups = common_reference(recording, reference='single',\n                                           groups=[[0, 1], [2, 3]],\n                                           ref_channel_ids=[0, 2])\n\n\ntrace0_car = recording_car.get_traces(segment_index=0)[:, 0]\ntrace0_cmr = recording_cmr.get_traces(segment_index=0)[:, 0]\ntrace0_single = recording_single.get_traces(segment_index=0)[:, 0]\nfig1, ax1 = plt.subplots()\nax1.plot(trace0_car)\nax1.plot(trace0_cmr)\nax1.plot(trace0_single)\n\ntrace1_groups = recording_single_groups.get_traces(segment_index=0)[:, 1]\ntrace0_groups = recording_single_groups.get_traces(segment_index=0)[:, 0]\nfig2, ax2 = plt.subplots()\nax2.plot(trace1_groups)  # channel 1 minus its group's reference (channel 0): not zero\nax2.plot(trace0_groups)  # channel 0 is its own group's reference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Remove stimulation artifacts\n\nIn some applications, electrodes are used to electrically stimulate the\ntissue, which generates a large artifact in the recording. In :py:mod:`spikeinterface.preprocessing`, the artifact\ncan be zeroed out using the :py:func:`~spikeinterface.preprocessing.remove_artifacts` function.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# create dummy stimulation triggers (in frames), one list per segment\nstimulation_trigger_frames = [\n    [10000, 150000, 200000],\n    [20000, 30000],\n]\n\n\n# large ms_before and ms_after values are used here only to make the removed window visible in the plot\nrecording_rm_artifact = remove_artifacts(recording, stimulation_trigger_frames,\n                                         ms_before=100, ms_after=200)\n\ntrace0 = recording.get_traces(segment_index=0)[:, 0]\ntrace0_rm = recording_rm_artifact.get_traces(segment_index=0)[:, 0]\nfig3, ax3 = plt.subplots()\nax3.plot(trace0)\nax3.plot(trace0_rm)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "You can list the available preprocessors with:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pprint import pprint\npprint(preprocesser_dict)\n\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}