{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Waveform Extractor\n\nSpikeInterface provides an efficient mechanism to extract waveform snippets.\n\nThe :py:class:`~spikeinterface.core.WaveformExtractor` class:\n\n  * randomly samples a subset spikes with max_spikes_per_unit\n  * extracts all waveforms snippets for each unit\n  * saves waveforms in a local folder\n  * can load stored waveforms\n  * retrieves template (average or median waveform) for each unit\n\nHere the how!\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n\nfrom spikeinterface import download_dataset\nfrom spikeinterface import WaveformExtractor, extract_waveforms\nimport spikeinterface.extractors as se"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First let's use the repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data\nto download a MEArec dataset. It is a simulated dataset that contains \"ground truth\"\nsorting information:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "repo = 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data'\nremote_path = 'mearec/mearec_test_10s.h5'\nlocal_path = download_dataset(repo=repo, remote_path=remote_path, local_folder=None)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's now instantiate the recording and sorting objects:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "recording = se.MEArecRecordingExtractor(local_path)\nprint(recording)\nsorting = se.MEArecSortingExtractor(local_path)\nprint(recording)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The MEArec dataset already contains a probe object that you can retrieve\nan plot:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "probe = recording.get_probe()\nprint(probe)\nfrom probeinterface.plotting import plot_probe\n\nplot_probe(probe)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A :py:class:`~spikeinterface.core.WaveformExtractor` object can be created with the\n:py:func:`~spikeinterface.core.extract_waveforms` function:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "folder = 'waveform_folder'\nwe = extract_waveforms(\n    recording,\n    sorting,\n    folder,\n    ms_before=1.5,\n    ms_after=2.,\n    max_spikes_per_unit=500,\n    load_if_exists=True,\n)\nprint(we)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Alternatively, the :py:class:`~spikeinterface.core.WaveformExtractor` object can be instantiated\ndirectly. In this case, we need to :py:func:`~spikeinterface.core.WaveformExtractor.set_params` to set the desired\nparameters:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "folder = 'waveform_folder2'\nwe = WaveformExtractor.create(recording, sorting, folder, remove_if_exists=True)\nwe.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=1000)\nwe.run_extract_waveforms(n_jobs=1, chunk_size=30000, progress_bar=True)\nprint(we)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The :code:`'waveform_folder'` folder contains:\n * the dumped recording (json)\n * the dumped sorting (json)\n * the parameters (json)\n * a subfolder with \"waveforms_XXX.npy\" and \"sampled_index_XXX.npy\"\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n\nprint(os.listdir(folder))\nprint(os.listdir(folder + '/waveforms'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we can retrieve waveforms per unit on-the-fly. The waveforms shape\nis (num_spikes, num_sample, num_channel):\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "unit_ids = sorting.unit_ids\n\nfor unit_id in unit_ids:\n    wfs = we.get_waveforms(unit_id)\n    print(unit_id, ':', wfs.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can also get the template for each units either using the median or the\naverage:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for unit_id in unit_ids[:3]:\n    fig, ax = plt.subplots()\n    template = we.get_template(unit_id=unit_id, mode='median')\n    print(template.shape)\n    ax.plot(template)\n    ax.set_title(f'{unit_id}')\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}