{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Curation Tutorial\n\nAfter spike sorting and computing quality metrics, you can use the quality metrics to automatically curate the\nspike sorting output, keeping only the units that pass a set of thresholds.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import spikeinterface as si\nimport spikeinterface.extractors as se\nfrom spikeinterface.postprocessing import compute_principal_components\nfrom spikeinterface.qualitymetrics import compute_quality_metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, let's download a simulated dataset\nfrom the repository 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data'.\n\nFor this tutorial, we treat the ground-truth sorting as if it were the output of a sorter.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')\nrecording, sorting = se.read_mearec(local_path)\nprint(recording)\nprint(sorting)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, we extract waveforms and compute their principal component (PC) scores, which the nearest-neighbor\nquality metrics need:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "folder = 'waveforms_mearec'\nwe = si.extract_waveforms(recording, sorting, folder,\n                          load_if_exists=True,\n                          ms_before=1, ms_after=2., max_spikes_per_unit=500,\n                          n_jobs=1, chunk_size=30000)\nprint(we)\n\npc = compute_principal_components(we, load_if_exists=True,\n                                     n_components=3, mode='by_channel_local')\nprint(pc)"
      ]
    },
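    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of what per-channel PC scores are, here is a minimal numpy sketch on random stand-in data. It assumes that `mode='by_channel_local'` fits a separate PCA on each channel and stores the scores with shape (spikes, components, channels). The array sizes and `toy_` variable names are hypothetical, and this is not SpikeInterface's implementation.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n# Hypothetical stand-in data: 200 spikes x 90 samples x 4 channels of random noise\nrng = np.random.default_rng(0)\ntoy_waveforms = rng.normal(size=(200, 90, 4))\nn_components = 3\n\n# One PCA per channel (a sketch of the 'by_channel_local' idea, not SpikeInterface's code):\n# center the channel's snippets, take the top right-singular vectors, project each spike on them\ntoy_scores = np.zeros((toy_waveforms.shape[0], n_components, toy_waveforms.shape[2]))\nfor ch in range(toy_waveforms.shape[2]):\n    X = toy_waveforms[:, :, ch]\n    Xc = X - X.mean(axis=0)\n    _, _, vt = np.linalg.svd(Xc, full_matrices=False)\n    toy_scores[:, :, ch] = Xc @ vt[:n_components].T\n\nprint(toy_scores.shape)  # (200, 3, 4): spikes x components x channels"
      ]
    },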
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
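        "Then we compute three quality metrics: the signal-to-noise ratio (`snr`), inter-spike-interval violations\n(`isi_violation`), and the PC-based nearest-neighbor metrics (`nearest_neighbor`):\n\n"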
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "metrics = compute_quality_metrics(we, metric_names=['snr', 'isi_violation', 'nearest_neighbor'])\nprint(metrics)"
      ]
    },
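    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To give some intuition for the `isi_violation` metric, the sketch below counts inter-spike intervals (ISIs) shorter than a refractory-period threshold for one hypothetical unit. The spike times and the 1.5 ms threshold are assumptions for illustration only. SpikeInterface turns such counts into a normalized rate, so use its output for actual curation.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n# Hypothetical spike times (in seconds) for a single unit\ntoy_spike_times = np.array([0.010, 0.011, 0.050, 0.120, 0.1205, 0.300])\n\n# Assumed refractory-period threshold of 1.5 ms (an illustrative choice)\nisi_threshold_s = 0.0015\n\n# Inter-spike intervals and the number that fall below the threshold\nisis = np.diff(toy_spike_times)\nn_violations = int(np.sum(isis < isi_threshold_s))\nprint(n_violations)  # 2"
      ]
    },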
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can now threshold each quality metric and keep only the units that satisfy all of our rules.\n\nThe easiest and most intuitive way is to use boolean masking on the `metrics` dataframe, which has one row per unit:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
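        "# keep units with a high SNR, a low ISI-violation rate, and a high nearest-neighbor hit rate\nkeep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_rate'] < 0.05) & (metrics['nn_hit_rate'] > 0.90)\nprint(keep_mask)\n\n# the dataframe index holds the unit ids\nkeep_unit_ids = keep_mask[keep_mask].index.values\nprint(keep_unit_ids)"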
      ]
    },
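    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same selection can also be expressed with pandas `DataFrame.query`, which can be easier to read when there are many criteria. This sketch runs on a small hypothetical dataframe that reuses the column names above. The values are made up, and the `toy_` variable names avoid overwriting `metrics` and `keep_unit_ids`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n\n# Hypothetical metric values for four units (made up; real values come from compute_quality_metrics)\ntoy_metrics = pd.DataFrame({\n    'snr': [12.0, 5.0, 9.0, 20.0],\n    'isi_violations_rate': [0.01, 0.0, 0.2, 0.03],\n    'nn_hit_rate': [0.95, 0.99, 0.97, 0.85],\n})\n\n# Same thresholds as the boolean mask above, written as a single query string\ntoy_curated = toy_metrics.query('snr > 7.5 and isi_violations_rate < 0.05 and nn_hit_rate > 0.90')\ntoy_keep_unit_ids = toy_curated.index.values\nprint(toy_keep_unit_ids)  # [0]"
      ]
    },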
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now let's create a sorting that contains only the curated units and save it,\nfor example to an NPZ file.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "curated_sorting = sorting.select_units(keep_unit_ids)\nprint(curated_sorting)\n\n# save the curated sorting to an NPZ file\nse.NpzSortingExtractor.write_sorting(curated_sorting, 'curated_sorting.npz')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}