{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Quality Metrics Tutorial\n\nAfter spike sorting, you might want to assess the quality of the sorted units. This can be done with the\n`qualitymetrics` submodule, which computes several quality metrics for each sorted unit.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import spikeinterface as si\nimport spikeinterface.extractors as se\nfrom spikeinterface.postprocessing import compute_principal_components\nfrom spikeinterface.qualitymetrics import (compute_snrs, compute_firing_rate, \n    compute_isi_violations, calculate_pc_metrics, compute_quality_metrics)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, let's download a simulated MEArec dataset\nfrom the repository 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data':\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')\nrecording, sorting = se.read_mearec(local_path)\nprint(recording)\nprint(sorting)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Extract spike waveforms\n\nFor convenience, metrics are computed on a `WaveformExtractor` object,\nbecause it holds references to both the recording and the sorting objects, as well as the extracted waveforms:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
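        "# Extract up to 500 waveforms per unit, from 1 ms before to 2 ms after each spike.\n# With load_if_exists=True, existing waveforms in `folder` are reloaded instead of recomputed.\nfolder = 'waveforms_mearec'\nwe = si.extract_waveforms(recording, sorting, folder,\n                          load_if_exists=True,\n                          ms_before=1., ms_after=2., max_spikes_per_unit=500,\n                          n_jobs=1, chunk_size=30000)\nprint(we)"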
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `spikeinterface.qualitymetrics` submodule provides a set of functions for computing\nmetrics compactly. To compute a single metric, call the corresponding quality metric function\ndirectly, as shown below. Each function accepts its own set of tunable parameters.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
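        "# Firing rate (in Hz) of each unit\nfiring_rates = compute_firing_rate(we)\nprint(firing_rates)\n# Inter-spike-interval (ISI) violations of each unit, reported as a ratio, a rate, and a count\nisi_violation_ratio, isi_violations_rate, isi_violations_count = compute_isi_violations(we)\nprint(isi_violation_ratio)\n# Signal-to-noise ratio of each unit\nsnrs = compute_snrs(we)\nprint(snrs)"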
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Some metrics are based on principal component (PC) scores, so they require a\n`WaveformPrincipalComponent` object, as returned by `compute_principal_components`, as input:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
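        "# Project waveforms onto 3 principal components, with a separate PCA fitted on each channel\npc = compute_principal_components(we, load_if_exists=True,\n                                     n_components=3, mode='by_channel_local')\nprint(pc)\n\n# Compute only the nearest-neighbor PC metrics\npc_metrics = calculate_pc_metrics(pc, metric_names=['nearest_neighbor'])\nprint(pc_metrics)"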
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To compute several metrics at once, use the `compute_quality_metrics` function. The optional\n`metric_names` argument selects which metrics to compute; when it is omitted, as below, a default set\nof metrics is computed. The result is a pandas DataFrame with one row per unit:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Compute the default set of quality metrics: one row per unit, one column per metric\nmetrics = compute_quality_metrics(we)\nprint(metrics)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}