{ "cells": [ { "cell_type": "markdown", "id": "62360393-aec3-4229-a4d0-5bedbd24d8bb", "metadata": {}, "source": [ "# Benchmark engines" ] }, { "cell_type": "markdown", "id": "b692c9a7-0a6d-455c-8e83-5663d4208d3b", "metadata": {}, "source": [ "## Overview\n", "\n", "In this notebook, we will test the run time of all engines in TenCirChem, along with their interplay with different backends.\n", "\n", "The hydrogen chain system is used as the benchmark platform. The benchmarked system size is from 2 atoms to 6 atoms." ] }, { "cell_type": "markdown", "id": "8fdf019d-5d55-4598-a596-76d34f40f473", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "id": "4f596c92-6fa5-4c69-b630-55c3a4d11345", "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from tencirchem import UCCSD, set_backend\n", "from tencirchem.molecule import h_chain" ] }, { "cell_type": "code", "execution_count": 2, "id": "b9858e7d", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "n_h_list = list(range(2, 8, 2))\n", "uccsd_list = [UCCSD(h_chain(n_h)) for n_h in n_h_list]\n", "params_list = [np.random.rand(uccsd.n_params) for uccsd in uccsd_list]" ] }, { "cell_type": "code", "execution_count": 3, "id": "5273281b", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# tensornetwork and statevector engine are only compatible with JAX backend\n", "jax_engines = [\"tensornetwork\", \"statevector\", \"civector\", \"civector-large\", \"pyscf\"]\n", "numpy_engines = [\"civector\", \"civector-large\", \"pyscf\"]\n", "cupy_engines = numpy_engines\n", "tested_engines_list = [jax_engines, numpy_engines, cupy_engines]" ] }, { "cell_type": "markdown", "id": "b3291494-34c6-45e8-93b4-0dbf5f98009f", "metadata": {}, "source": [ "## Benchmark" ] }, { "cell_type": "code", "execution_count": 5, "id": "a19eeeee", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('jax', 4, 'tensornetwork', 0.6734888553619385, 0.0013010025024414063, 0.6995089054107666)\n", "('jax', 4, 'statevector', 0.5532455444335938, 0.0008256316184997558, 0.5697581768035889)\n", "('jax', 4, 'civector', 0.6092183589935303, 0.0061431884765625, 0.7320821285247803)\n", "('jax', 4, 'civector-large', 0.6665019989013672, 0.0064354419708251955, 0.7952108383178711)\n", "('jax', 4, 'pyscf', 0.0294039249420166, 0.016276955604553223, 0.35494303703308105)\n", "('jax', 8, 'tensornetwork', 4.530415773391724, 0.003479158878326416, 4.599998950958252)\n", "('jax', 8, 'statevector', 2.3404414653778076, 0.001799321174621582, 2.3764278888702393)\n", "('jax', 8, 'civector', 0.6547484397888184, 0.007375049591064453, 0.8022494316101074)\n", "('jax', 8, 'civector-large', 1.63521146774292, 0.006232154369354248, 1.7598545551300049)\n", "('jax', 8, 'pyscf', 0.1072854995727539, 0.09403668642044068, 1.9880192279815674)\n", "('jax', 12, 'tensornetwork', 24.58689785003662, 0.05817370414733887, 25.7503719329834)\n", "('jax', 12, 'statevector', 7.56268572807312, 0.005127060413360596, 7.665226936340332)\n", "('jax', 12, 'civector', 0.9048683643341064, 0.01238323450088501, 1.1525330543518066)\n", "('jax', 12, 'civector-large', 8.32372498512268, 0.00833052396774292, 8.490335464477539)\n", "('jax', 12, 'pyscf', 0.4856069087982178, 0.4437950611114502, 9.361508131027222)\n", "('numpy', 4, 'civector', 0.0034182071685791016, 0.003229069709777832, 0.06799960136413574)\n", "('numpy', 4, 'civector-large', 0.0037720203399658203, 0.003747880458831787, 0.07872962951660156)\n", "('numpy', 4, 'pyscf', 0.01477956771850586, 0.014781630039215088, 0.3104121685028076)\n", "('numpy', 8, 'civector', 0.005614757537841797, 0.003976082801818848, 0.08513641357421875)\n", "('numpy', 8, 'civector-large', 0.0069615840911865234, 0.006995594501495362, 0.14687347412109375)\n", "('numpy', 8, 'pyscf', 0.09161663055419922, 0.09203903675079346, 1.9323973655700684)\n", "('numpy', 12, 'civector', 0.01762080192565918, 0.008578836917877197, 0.18919754028320312)\n", "('numpy', 12, 'civector-large', 0.02767157554626465, 0.02590758800506592, 0.545823335647583)\n", "('numpy', 12, 'pyscf', 0.4722292423248291, 0.443195104598999, 9.33613133430481)\n", "('cupy', 4, 'civector', 0.32034993171691895, 0.008597338199615478, 0.4922966957092285)\n", "('cupy', 4, 'civector-large', 0.02430891990661621, 0.014339327812194824, 0.3110954761505127)\n", "('cupy', 4, 'pyscf', 0.01539301872253418, 0.01657602787017822, 0.34691357612609863)\n", "('cupy', 8, 'civector', 0.036718130111694336, 0.01722104549407959, 0.38113903999328613)\n", "('cupy', 8, 'civector-large', 0.05296659469604492, 0.052764499187469484, 1.1082565784454346)\n", "('cupy', 8, 'pyscf', 0.09645533561706543, 0.09423872232437133, 1.9812297821044922)\n", "('cupy', 12, 'civector', 0.1256849765777588, 0.04901479482650757, 1.1059808731079102)\n", "('cupy', 12, 'civector-large', 0.18027353286743164, 0.1753893733024597, 3.688060998916626)\n", "('cupy', 12, 'pyscf', 0.4615364074707031, 0.45049295425415037, 9.471395492553711)\n" ] } ], "source": [ "time_data = []\n", "for backend, tested_engines in zip([\"jax\", \"numpy\", \"cupy\"], tested_engines_list):\n", " set_backend(backend)\n", " for uccsd, params in zip(uccsd_list, params_list):\n", " for engine in tested_engines:\n", " # dry run first. Let it compile or build caches\n", " time1 = time.time()\n", " uccsd.energy_and_grad(params, engine=engine)\n", " time2 = time.time()\n", " staging_time = time2 - time1\n", " # several real runs. Assuming `n_run` evaluations during the optimization\n", " n_run = 20\n", " for i in range(n_run):\n", " uccsd.energy_and_grad(params, engine=engine)\n", " run_time = (time.time() - time2) / n_run\n", " item = (backend, uccsd.n_qubits, engine, staging_time, run_time, staging_time + n_run * run_time)\n", " print(item)\n", " time_data.append(item)" ] }, { "cell_type": "markdown", "id": "d580b88a-1be1-4d2f-8e70-23cd1dcf9700", "metadata": {}, "source": [ "## Results and Discussion" ] }, { "cell_type": "code", "execution_count": 6, "id": "9bde172c", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
staging timerun timetotal time
backendqubitsengine
jax4tensornetwork0.6734890.0013010.699509
statevector0.5532460.0008260.569758
civector0.6092180.0061430.732082
civector-large0.6665020.0064350.795211
pyscf0.0294040.0162770.354943
8tensornetwork4.5304160.0034794.599999
statevector2.3404410.0017992.376428
civector0.6547480.0073750.802249
civector-large1.6352110.0062321.759855
pyscf0.1072850.0940371.988019
12tensornetwork24.5868980.05817425.750372
statevector7.5626860.0051277.665227
civector0.9048680.0123831.152533
civector-large8.3237250.0083318.490335
pyscf0.4856070.4437959.361508
numpy4civector0.0034180.0032290.068000
civector-large0.0037720.0037480.078730
pyscf0.0147800.0147820.310412
8civector0.0056150.0039760.085136
civector-large0.0069620.0069960.146873
pyscf0.0916170.0920391.932397
12civector0.0176210.0085790.189198
civector-large0.0276720.0259080.545823
pyscf0.4722290.4431959.336131
cupy4civector0.3203500.0085970.492297
civector-large0.0243090.0143390.311095
pyscf0.0153930.0165760.346914
8civector0.0367180.0172210.381139
civector-large0.0529670.0527641.108257
pyscf0.0964550.0942391.981230
12civector0.1256850.0490151.105981
civector-large0.1802740.1753893.688061
pyscf0.4615360.4504939.471395
\n", "
" ], "text/plain": [ " staging time run time total time\n", "backend qubits engine \n", "jax 4 tensornetwork 0.673489 0.001301 0.699509\n", " statevector 0.553246 0.000826 0.569758\n", " civector 0.609218 0.006143 0.732082\n", " civector-large 0.666502 0.006435 0.795211\n", " pyscf 0.029404 0.016277 0.354943\n", " 8 tensornetwork 4.530416 0.003479 4.599999\n", " statevector 2.340441 0.001799 2.376428\n", " civector 0.654748 0.007375 0.802249\n", " civector-large 1.635211 0.006232 1.759855\n", " pyscf 0.107285 0.094037 1.988019\n", " 12 tensornetwork 24.586898 0.058174 25.750372\n", " statevector 7.562686 0.005127 7.665227\n", " civector 0.904868 0.012383 1.152533\n", " civector-large 8.323725 0.008331 8.490335\n", " pyscf 0.485607 0.443795 9.361508\n", "numpy 4 civector 0.003418 0.003229 0.068000\n", " civector-large 0.003772 0.003748 0.078730\n", " pyscf 0.014780 0.014782 0.310412\n", " 8 civector 0.005615 0.003976 0.085136\n", " civector-large 0.006962 0.006996 0.146873\n", " pyscf 0.091617 0.092039 1.932397\n", " 12 civector 0.017621 0.008579 0.189198\n", " civector-large 0.027672 0.025908 0.545823\n", " pyscf 0.472229 0.443195 9.336131\n", "cupy 4 civector 0.320350 0.008597 0.492297\n", " civector-large 0.024309 0.014339 0.311095\n", " pyscf 0.015393 0.016576 0.346914\n", " 8 civector 0.036718 0.017221 0.381139\n", " civector-large 0.052967 0.052764 1.108257\n", " pyscf 0.096455 0.094239 1.981230\n", " 12 civector 0.125685 0.049015 1.105981\n", " civector-large 0.180274 0.175389 3.688061\n", " pyscf 0.461536 0.450493 9.471395" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.DataFrame(\n", " time_data, columns=[\"backend\", \"qubits\", \"engine\", \"staging time\", \"run time\", \"total time\"]\n", ").set_index([\"backend\", \"qubits\", \"engine\"])\n", "df" ] }, { "cell_type": "markdown", "id": "d0272350-9659-4eab-a8b2-f6c3ab413e6c", "metadata": {}, "source": [ "The table contains rich information, but conclusion is not easily drawn. \n", "\n", "Next, we find out the best option for each system size." ] }, { "cell_type": "code", "execution_count": 7, "id": "7f0cc1ff", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('jax', 4, 'statevector') ('numpy', 4, 'civector')\n", "('jax', 8, 'statevector') ('numpy', 8, 'civector')\n", "('jax', 12, 'statevector') ('numpy', 12, 'civector')\n" ] } ], "source": [ "interesting_indices = []\n", "for i, ddf in df.groupby(\"qubits\"):\n", " run_time_idx = ddf[\"run time\"].idxmin()\n", " total_time_idx = ddf[\"total time\"].idxmin()\n", " print(run_time_idx, total_time_idx)\n", " interesting_indices.extend([run_time_idx, total_time_idx])" ] }, { "cell_type": "code", "execution_count": 8, "id": "b467181e-1a9c-4e87-9f75-d2c52bf845e5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
staging timerun timetotal time
backendqubitsengine
jax4statevector0.5532460.0008260.569758
numpy4civector0.0034180.0032290.068000
jax8statevector2.3404410.0017992.376428
numpy8civector0.0056150.0039760.085136
jax12statevector7.5626860.0051277.665227
numpy12civector0.0176210.0085790.189198
\n", "
" ], "text/plain": [ " staging time run time total time\n", "backend qubits engine \n", "jax 4 statevector 0.553246 0.000826 0.569758\n", "numpy 4 civector 0.003418 0.003229 0.068000\n", "jax 8 statevector 2.340441 0.001799 2.376428\n", "numpy 8 civector 0.005615 0.003976 0.085136\n", "jax 12 statevector 7.562686 0.005127 7.665227\n", "numpy 12 civector 0.017621 0.008579 0.189198" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.loc[interesting_indices]" ] }, { "cell_type": "markdown", "id": "74cf0731-4ccf-4320-8ce3-5d7843820e05", "metadata": {}, "source": [ "For every system size tested, JAX + statevector is the fastest in terms of run time.\n", "\n", "However, if the staging time is taken into account, then NumPy + civector is most efficient.\n", "\n", "We note that the conclusion here is only valid for system size <= 16 qubits. \n", "For larger system CuPy + civector-large is the most scalable choice." ] }, { "cell_type": "code", "execution_count": null, "id": "ad9a8940-9e87-4e26-9b6e-e6d8d8eeb15b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 5 }