{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Compare two sorters\n\nThis example shows how to compare the results of two spike sorters.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Import\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport matplotlib.pyplot as plt\n\nimport spikeinterface as si\nimport spikeinterface.extractors as se\nimport spikeinterface.sorters as ss\nimport spikeinterface.comparison as sc\nimport spikeinterface.widgets as sw"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, let's download a simulated MEArec dataset\nfrom the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data'.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')\nrecording, sorting = se.read_mearec(local_path)\nprint(recording)\nprint(sorting)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then, run two spike sorters, HerdingSpikes and Tridesclous, on the same recording\n(both sorters need to be installed).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "sorting_HS = ss.run_herdingspikes(recording)\nsorting_TDC = ss.run_tridesclous(recording)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The :py:func:`~spikeinterface.comparison.compare_two_sorters` function compares the outputs of two spike\nsorters. It returns a :py:class:`~spikeinterface.comparison.SymmetricSortingComparison` object, with methods\nto easily inspect the comparison output. Units are matched based on the agreement\nbetween their spike trains.\n\nLet\u2019s see how to inspect and access this matching.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "cmp_HS_TDC = sc.compare_two_sorters(\n    sorting1=sorting_HS,\n    sorting2=sorting_TDC,\n    sorting1_name='HS',\n    sorting2_name='TDC',\n)"
      ]
    },
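    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The agreement between two units can be sketched in plain Python. We assume it follows the definition given in the SpikeInterface documentation: the number of matched spikes divided by the size of the union of the two spike trains, i.e. :code:`n_match / (n1 + n2 - n_match)`, where two spikes match if they occur within a small time tolerance. The cell below is a minimal sketch under stated assumptions. The toy spike trains, the 30 kHz sampling rate and the 0.4 ms tolerance are illustrative values. The greedy one-to-one counting (with hypothetical helper names :code:`count_matching_events` and :code:`agreement_score`) may differ from SpikeInterface's own implementation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def count_matching_events(times1, times2, delta):\n",
        "    # Greedy one-to-one matching on two sorted spike trains (hypothetical helper):\n",
        "    # two spikes match when they are at most `delta` samples apart,\n",
        "    # and each spike is used at most once.\n",
        "    i = j = n_match = 0\n",
        "    while i < len(times1) and j < len(times2):\n",
        "        diff = times1[i] - times2[j]\n",
        "        if abs(diff) <= delta:\n",
        "            n_match += 1\n",
        "            i += 1\n",
        "            j += 1\n",
        "        elif diff < 0:\n",
        "            i += 1  # spike i is too early to match any remaining spike in times2\n",
        "        else:\n",
        "            j += 1  # spike j is too early to match any remaining spike in times1\n",
        "    return n_match\n",
        "\n",
        "\n",
        "def agreement_score(times1, times2, delta):\n",
        "    # matched spikes divided by the union of both spike trains\n",
        "    n_match = count_matching_events(times1, times2, delta)\n",
        "    union = len(times1) + len(times2) - n_match\n",
        "    return n_match / union if union > 0 else 0.0\n",
        "\n",
        "\n",
        "# hypothetical toy spike trains, in samples\n",
        "fs = 30000.0  # assumed sampling rate (Hz)\n",
        "delta = int(round(0.4e-3 * fs))  # assumed 0.4 ms tolerance -> 12 samples\n",
        "toy_st1 = np.array([100, 500, 1000, 2000])\n",
        "toy_st2 = np.array([105, 510, 1500, 2003])\n",
        "\n",
        "print(count_matching_events(toy_st1, toy_st2, delta))  # -> 3\n",
        "print(agreement_score(toy_st1, toy_st2, delta))  # -> 0.6"
      ]
    },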
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can check the agreement matrix to inspect the matching.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "sw.plot_agreement_matrix(cmp_HS_TDC)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Two internal dataframes are useful to check the matching: **match_event_count**\n(the number of matched spikes for each pair of units) and **agreement_scores**\n(the agreement score for each pair of units).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(cmp_HS_TDC.match_event_count)\nprint(cmp_HS_TDC.agreement_scores)"
      ]
    },
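    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "These scores are then used to pair units. One common approach, which we assume is close to what :code:`get_matching` relies on, is a Hungarian (optimal one-to-one) assignment that maximizes the total agreement, followed by a threshold. Pairs whose score falls below a minimum value are rejected and their units are marked as unmatched with -1. The sketch below applies :code:`scipy.optimize.linear_sum_assignment` to a hypothetical agreement matrix. The unit ids, the scores and the 0.5 threshold are illustrative assumptions.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from scipy.optimize import linear_sum_assignment\n",
        "\n",
        "# hypothetical agreement scores: rows = units of sorter 1, columns = units of sorter 2\n",
        "units1 = [0, 1, 2]\n",
        "units2 = [10, 11, 12, 13]\n",
        "scores = np.array([\n",
        "    [0.90, 0.05, 0.00, 0.10],\n",
        "    [0.02, 0.30, 0.70, 0.00],\n",
        "    [0.00, 0.20, 0.10, 0.40],\n",
        "])\n",
        "match_score = 0.5  # assumed minimum agreement for accepting a pair\n",
        "\n",
        "# optimal one-to-one assignment maximizing the summed agreement\n",
        "rows, cols = linear_sum_assignment(scores, maximize=True)\n",
        "\n",
        "# start with every unit unmatched (-1), then keep pairs above the threshold\n",
        "match_12 = {u: -1 for u in units1}\n",
        "match_21 = {u: -1 for u in units2}\n",
        "for r, c in zip(rows, cols):\n",
        "    if scores[r, c] >= match_score:\n",
        "        match_12[units1[r]] = units2[c]\n",
        "        match_21[units2[c]] = units1[r]\n",
        "\n",
        "print(match_12)  # -> {0: 10, 1: 12, 2: -1}\n",
        "print(match_21)  # -> {10: 0, 11: -1, 12: 1, 13: -1}"
      ]
    },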
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To check which units were matched, use the :code:`get_matching`\nmethod. It returns two mappings (HS to TDC and TDC to HS); units without a match are listed as -1.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "hs_to_tdc, tdc_to_hs = cmp_HS_TDC.get_matching()\n\nprint('matching HS to TDC')\nprint(hs_to_tdc)\nprint('matching TDC to HS')\nprint(tdc_to_hs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The :py:meth:`~spikeinterface.core.BaseSorting.get_unit_spike_train` method returns the spike train of a unit\n(spike times in samples). We can use it to check that the spike times of a matched pair of units correspond.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# keep only the HS units that have a TDC match (-1 marks unmatched units)\nmatched_ids = hs_to_tdc[hs_to_tdc != -1]\n\n# take the first matched pair\nunit_id_HS = matched_ids.index[0]\nunit_id_TDC = matched_ids[unit_id_HS]\n\n# check that matched spike trains correspond\nst1 = sorting_HS.get_unit_spike_train(unit_id_HS)\nst2 = sorting_TDC.get_unit_spike_train(unit_id_TDC)\nfig, ax = plt.subplots()\nax.plot(st1, np.zeros(st1.size), '|')\nax.plot(st2, np.ones(st2.size), '|')\nax.set_yticks([0, 1])\nax.set_yticklabels(['HS', 'TDC'])\nax.set_xlabel('time (samples)')\n\nplt.show()"
      ]
    }
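    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The raster plot gives a visual check. A numeric complement is to compute, for every spike of one unit, the signed offset to the closest spike of the other unit. The cell below is a minimal sketch using :code:`numpy.searchsorted` on hypothetical toy spike trains (in samples), with an assumed tolerance of 12 samples (0.4 ms at an assumed 30 kHz). The helper name :code:`nearest_lags` is hypothetical. Unlike a one-to-one matching, a spike of the second train can be the nearest neighbor of several spikes of the first train, so this is only a quick diagnostic. The same function could be applied to :code:`st1` and :code:`st2` from the previous cell.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def nearest_lags(times1, times2):\n",
        "    # For each spike in times1, return the signed offset (times2 - times1)\n",
        "    # to the closest spike in times2; times2 must be non-empty.\n",
        "    times2 = np.sort(times2)\n",
        "    idx = np.searchsorted(times2, times1)  # first spike in times2 at or after each spike\n",
        "    left = np.clip(idx - 1, 0, len(times2) - 1)  # neighbor before (clipped at the start)\n",
        "    right = np.clip(idx, 0, len(times2) - 1)  # neighbor at or after (clipped at the end)\n",
        "    lag_left = times2[left] - times1\n",
        "    lag_right = times2[right] - times1\n",
        "    return np.where(np.abs(lag_left) <= np.abs(lag_right), lag_left, lag_right)\n",
        "\n",
        "\n",
        "# hypothetical toy spike trains, in samples\n",
        "toy_st1 = np.array([100, 500, 1000, 2000])\n",
        "toy_st2 = np.array([105, 510, 1500, 2003])\n",
        "delta = 12  # assumed tolerance in samples\n",
        "\n",
        "lags = nearest_lags(toy_st1, toy_st2)\n",
        "print(lags)\n",
        "print('fraction within tolerance:', np.mean(np.abs(lags) <= delta))"
      ]
    }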
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}