{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Postprocessing Tutorial\n\nSpike sorters generally output a set of units with corresponding spike trains. The :code:`postprocessing`\nsubmodule allows to combine the :code:`RecordingExtractor` and the sorted :code:`SortingExtractor` objects to perform\nfurther postprocessing.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pylab as plt\n\nimport spikeinterface as si\nimport spikeinterface.extractors as se\nfrom spikeinterface.postprocessing import get_template_extremum_channel, compute_principal_components"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, let's download a simulated dataset\nfrom the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data'\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')\nrecording, sorting = se.read_mearec(local_path)\nprint(recording)\nprint(sorting)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Assuming the :code:`sorting` is the output of a spike sorter, the\n:code:`postprocessing` module allows to extract all relevant information\nfrom the paired recording-sorting.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Compute spike waveforms\n\nWaveforms are extracted with the WaveformExtractor or directly with the\n:code:`extract_waveforms` function (which returns a\n:code:`WaveformExtractor` object):\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "folder = 'waveforms_mearec'\nwe = si.extract_waveforms(recording, sorting, folder,\n                          load_if_exists=True,\n                          ms_before=1, ms_after=2., max_spikes_per_unit=500,\n                          n_jobs=1, chunk_size=30000)\nprint(we)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's plot the waveforms of units [0, 1, 2] on channel 8\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "colors = ['Olive', 'Teal', 'Fuchsia']\n\nfig, ax = plt.subplots()\nfor i, unit_id in enumerate(sorting.unit_ids[:3]):\n    wf = we.get_waveforms(unit_id)\n    color = colors[i]\n    ax.plot(wf[:, :, 8].T, color=color, lw=0.3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Compute unit templates\n\u00a0\nSimilarly to waveforms, templates - average waveforms - can be easily retrieved\nfrom the :code:`WaveformExtractor` object:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig, ax = plt.subplots()\nfor i, unit_id in enumerate(sorting.unit_ids[:3]):\n    template = we.get_template(unit_id)\n    color = colors[i]\n    ax.plot(template[:, 8].T, color=color, lw=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Compute unit maximum channel\n\u00a0\nIn a similar way, one can get the recording channel with the 'extremum' signal\n(minimum or maximum). The code:`get_template_extremum_channel` outputs a\ndictionary unit_ids as keys and channel_ids as values:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "extremum_channels_ids = get_template_extremum_channel(we, peak_sign='neg')\nprint(extremum_channels_ids)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Compute principal components (aka PCs)\n\u00a0\nComputing PCA scores for each waveforms is very common for many applications,\nincluding unsupervised validation of the spike sorting performance.\n\nThere are different ways to compute PC scores from waveforms:\n  * \"concatenated\": all waveforms are concatenated and a single PCA model is computed (channel information is lost)\n  * \"by_channel_global\": PCA is computed from a subset of all waveforms and applied independently on each channel\n  * \"by_channel_local\": PCA is computed and applied to each channel separately\n\nIn SI, we can compute PC scores with the :code:`compute_principal_components` function\n(which returns a :code:`WaveformPrincipalComponent` object).\nThe pc scores for a unit are retrieved with the :code:`get_projections` function and\nthe shape of the pc scores is (n_spikes, n_components, n_channels).\nHere, we compute PC scores and plot the first and second components of channel 8:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pc = compute_principal_components(we, load_if_exists=True,\n                                  n_components=3, mode='by_channel_local')\nprint(pc)\n\nfig, ax = plt.subplots()\nfor i, unit_id in enumerate(sorting.unit_ids[:3]):\n    comp = pc.get_projections(unit_id)\n    print(comp.shape)\n    color = colors[i]\n    ax.scatter(comp[:, 0, 8], comp[:, 1, 8], color=color)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that PC scores for all units can be retrieved at once with the\n`get_all_projections()` function:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "all_labels, all_projections = pc.get_all_projections()\nprint(all_labels[:40])\nprint(all_labels.shape)\nprint(all_projections.shape)\n\ncmap = plt.get_cmap('Dark2', len(sorting.unit_ids))\n\nfig, ax = plt.subplots()\nfor i, unit_id in enumerate(sorting.unit_ids):\n    mask = all_labels == unit_id\n    comp = all_projections[mask, :, :]\n    ax.scatter(comp[:, 0, 8], comp[:, 1, 8], color=cmap(i))\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}