{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Ground truth study tutorial\n\nThis tutorial illustrates how to run a \"study\".\nA study is a systematic performance comparison of several ground-truth\nrecordings with several sorters.\n\nThe study submodule and its study class provide high-level tools\nto run many ground-truth comparisons with many sorters on many recordings\nand then collect and aggregate the results in an easy way.\n\nThe whole mechanism is based on an intrinsic organization\ninto a \"study_folder\" with several subfolders:\n\n* raw_files: contains a copy of the recordings in binary format\n* sorter_folders: contains the outputs of the sorters\n* ground_truth: contains a copy of the ground-truth sortings in npz format\n* sortings: contains a light copy of all sortings in npz format\n* tables: contains some tables in csv format\n\nSo that the computation can be run and rerun, all gt_sorting objects and\nrecordings are copied to a fast and universal format:\nbinary (for recordings) and npz (for sortings).\n"
      ]
    },
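    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The folder layout listed above can be checked with a short helper. This is a minimal sketch using only the standard library; `check_study_layout` and `EXPECTED_SUBFOLDERS` are hypothetical names introduced here for illustration and are not part of spikeinterface.\n",
        "\n",
        "```python\n",
        "from pathlib import Path\n",
        "\n",
        "# Subfolder names taken from the list above.\n",
        "EXPECTED_SUBFOLDERS = ('raw_files', 'sorter_folders', 'ground_truth', 'sortings', 'tables')\n",
        "\n",
        "\n",
        "def check_study_layout(study_folder):\n",
        "    # Hypothetical helper: map each expected subfolder name to True/False\n",
        "    # depending on whether it currently exists inside the study folder.\n",
        "    root = Path(study_folder)\n",
        "    return {name: (root / name).is_dir() for name in EXPECTED_SUBFOLDERS}\n",
        "```\n",
        "\n",
        "Calling `check_study_layout` on a study folder at different stages shows which subfolders are present; some of them may only appear once the corresponding step has been run."
      ]
    },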
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Imports\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\nimport seaborn as sns\n\nimport spikeinterface.extractors as se\nimport spikeinterface.widgets as sw\nfrom spikeinterface.comparison import GroundTruthStudy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup study folder and run all sorters\n\nWe first generate the folder.\nThis can take some time because the recordings are copied inside the folder.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "rec0, gt_sorting0 = se.toy_example(num_channels=4, duration=10, seed=10, num_segments=1)\nrec1, gt_sorting1 = se.toy_example(num_channels=4, duration=10, seed=0, num_segments=1)\ngt_dict = {\n    'rec0': (rec0, gt_sorting0),\n    'rec1': (rec1, gt_sorting1),\n}\nstudy_folder = 'a_study_folder'\nstudy = GroundTruthStudy.create(study_folder, gt_dict)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then run all sorters on all recordings with a single function call.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# sorter_list = ss.available_sorters()  # gets all sorters (needs `import spikeinterface.sorters as ss`)\nsorter_list = ['herdingspikes', 'tridesclous']\nstudy.run_sorters(sorter_list, mode_if_folder_exists=\"keep\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "You can re-run **run_sorters** as many times as you want.\nWith **mode_if_folder_exists='keep'**, only sorters that have not been computed yet are run.\nFor instance, remove the \"sorter_folders/rec1/herdingspikes\" folder to re-run\nonly one sorter on one recording.\n\nThen we copy the spike sorting outputs into a separate subfolder.\nThis makes it possible to remove the \"large\" sorter_folders.\n\n"
      ]
    },
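    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Removing one sorter output, as described above, can be scripted. The sketch below assumes the layout `<study_folder>/sorter_folders/<rec_name>/<sorter_name>` implied by the path in the text; `remove_sorter_output` is a hypothetical helper, not a spikeinterface function.\n",
        "\n",
        "```python\n",
        "import shutil\n",
        "from pathlib import Path\n",
        "\n",
        "\n",
        "def remove_sorter_output(study_folder, rec_name, sorter_name):\n",
        "    # Hypothetical helper: delete the output folder of one sorter on one recording\n",
        "    # so that it is recomputed the next time the sorters are run.\n",
        "    # Assumed layout: <study_folder>/sorter_folders/<rec_name>/<sorter_name>\n",
        "    folder = Path(study_folder) / 'sorter_folders' / rec_name / sorter_name\n",
        "    if folder.is_dir():\n",
        "        shutil.rmtree(folder)\n",
        "        return True  # something was removed\n",
        "    return False  # nothing to remove\n",
        "```\n",
        "\n",
        "For example, `remove_sorter_output(study_folder, 'rec1', 'herdingspikes')` followed by another call to `study.run_sorters(sorter_list, mode_if_folder_exists='keep')` should re-run only that sorter on that recording."
      ]
    },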
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "study.copy_sortings()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Collect comparisons\n\nYou can collect all results in one shot and run the\nGroundTruthComparison on them.\nThis gives you fine-grained access to all individual results.\n\nNote that exhaustive_gt=True applies when you know exactly how many\nunits are in the ground truth (for synthetic datasets).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "study.run_comparisons(exhaustive_gt=True)\n\nfor (rec_name, sorter_name), comp in study.comparisons.items():\n    print('*' * 10)\n    print(rec_name, sorter_name)\n    print(comp.count_score)  # raw counting of tp/fp/...\n    comp.print_summary()\n    perf_unit = comp.get_performance(method='by_unit')\n    perf_avg = comp.get_performance(method='pooled_with_average')\n    m = comp.get_confusion_matrix()\n    w_comp = sw.plot_agreement_matrix(comp)\n    w_comp.ax.set_title(rec_name  + ' - ' + sorter_name)"
      ]
    },
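    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop above prints raw tp/fp/fn counts and then asks for per-unit performance. As a rough illustration of how such counts relate to performance metrics, the sketch below uses commonly used definitions of accuracy, recall and precision. These formulas are an assumption made for illustration and should be checked against the spikeinterface documentation; `performance_from_counts` is a hypothetical helper.\n",
        "\n",
        "```python\n",
        "def performance_from_counts(tp, fp, fn):\n",
        "    # Hypothetical helper with assumed metric definitions:\n",
        "    #   accuracy  = tp / (tp + fn + fp)\n",
        "    #   recall    = tp / (tp + fn)\n",
        "    #   precision = tp / (tp + fp)\n",
        "    def ratio(num, den):\n",
        "        # Return NaN instead of raising when a denominator is zero.\n",
        "        return num / den if den > 0 else float('nan')\n",
        "\n",
        "    return {\n",
        "        'accuracy': ratio(tp, tp + fn + fp),\n",
        "        'recall': ratio(tp, tp + fn),\n",
        "        'precision': ratio(tp, tp + fp),\n",
        "    }\n",
        "```\n",
        "\n",
        "With these definitions, a unit with many missed spikes (high fn) has low recall, while a unit with many spurious spikes (high fp) has low precision."
      ]
    },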
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Collect synthetic dataframes and display\n\nAs shown previously, the performance is returned as a pandas dataframe.\nThe `aggregate_dataframes()` method of the study gathers all the outputs in\nthe study folder and merges them into aggregated dataframes, returned in a dictionary.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dataframes = study.aggregate_dataframes()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Pandas dataframes can be nicely displayed as tables in the notebook.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(dataframes.keys())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(dataframes['run_times'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Easy plot with seaborn\n\nSeaborn makes it easy to plot pandas dataframes. Let\u2019s see some\nexamples.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "run_times = dataframes['run_times']\nfig1, ax1 = plt.subplots()\nsns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1)\nax1.set_title('Run times')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "perfs = dataframes['perf_by_unit']\nfig2, ax2 = plt.subplots()\nsns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2)\nax2.set_title('Recall')\nax2.set_ylim(-0.1, 1.1)\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}