Benchmark engines

Overview

In this notebook, we benchmark the run time of all engines in TenCirChem, along with their interplay with different backends.

Hydrogen chains are used as the benchmark systems, with sizes ranging from 2 to 6 atoms (4 to 12 qubits).

Setup

[1]:
import time

import numpy as np
import pandas as pd

from tencirchem import UCCSD, set_backend
from tencirchem.molecule import h_chain
[2]:
# hydrogen chains with 2, 4 and 6 atoms (4, 8 and 12 qubits)
n_h_list = list(range(2, 8, 2))
uccsd_list = [UCCSD(h_chain(n_h)) for n_h in n_h_list]
# the same random parameters are reused across engines and backends for a fair comparison
params_list = [np.random.rand(uccsd.n_params) for uccsd in uccsd_list]
[3]:
# The tensornetwork and statevector engines are only compatible with the JAX backend
jax_engines = ["tensornetwork", "statevector", "civector", "civector-large", "pyscf"]
numpy_engines = ["civector", "civector-large", "pyscf"]
cupy_engines = numpy_engines
tested_engines_list = [jax_engines, numpy_engines, cupy_engines]

Benchmark

[5]:
time_data = []
for backend, tested_engines in zip(["jax", "numpy", "cupy"], tested_engines_list):
    set_backend(backend)
    for uccsd, params in zip(uccsd_list, params_list):
        for engine in tested_engines:
            # dry run first, to let the engine compile or build caches
            time1 = time.time()
            uccsd.energy_and_grad(params, engine=engine)
            time2 = time.time()
            staging_time = time2 - time1
            # several real runs, assuming `n_run` evaluations during an optimization
            n_run = 20
            for i in range(n_run):
                uccsd.energy_and_grad(params, engine=engine)
            run_time = (time.time() - time2) / n_run
            item = (backend, uccsd.n_qubits, engine, staging_time, run_time, staging_time + n_run * run_time)
            print(item)
            time_data.append(item)
('jax', 4, 'tensornetwork', 0.6734888553619385, 0.0013010025024414063, 0.6995089054107666)
('jax', 4, 'statevector', 0.5532455444335938, 0.0008256316184997558, 0.5697581768035889)
('jax', 4, 'civector', 0.6092183589935303, 0.0061431884765625, 0.7320821285247803)
('jax', 4, 'civector-large', 0.6665019989013672, 0.0064354419708251955, 0.7952108383178711)
('jax', 4, 'pyscf', 0.0294039249420166, 0.016276955604553223, 0.35494303703308105)
('jax', 8, 'tensornetwork', 4.530415773391724, 0.003479158878326416, 4.599998950958252)
('jax', 8, 'statevector', 2.3404414653778076, 0.001799321174621582, 2.3764278888702393)
('jax', 8, 'civector', 0.6547484397888184, 0.007375049591064453, 0.8022494316101074)
('jax', 8, 'civector-large', 1.63521146774292, 0.006232154369354248, 1.7598545551300049)
('jax', 8, 'pyscf', 0.1072854995727539, 0.09403668642044068, 1.9880192279815674)
('jax', 12, 'tensornetwork', 24.58689785003662, 0.05817370414733887, 25.7503719329834)
('jax', 12, 'statevector', 7.56268572807312, 0.005127060413360596, 7.665226936340332)
('jax', 12, 'civector', 0.9048683643341064, 0.01238323450088501, 1.1525330543518066)
('jax', 12, 'civector-large', 8.32372498512268, 0.00833052396774292, 8.490335464477539)
('jax', 12, 'pyscf', 0.4856069087982178, 0.4437950611114502, 9.361508131027222)
('numpy', 4, 'civector', 0.0034182071685791016, 0.003229069709777832, 0.06799960136413574)
('numpy', 4, 'civector-large', 0.0037720203399658203, 0.003747880458831787, 0.07872962951660156)
('numpy', 4, 'pyscf', 0.01477956771850586, 0.014781630039215088, 0.3104121685028076)
('numpy', 8, 'civector', 0.005614757537841797, 0.003976082801818848, 0.08513641357421875)
('numpy', 8, 'civector-large', 0.0069615840911865234, 0.006995594501495362, 0.14687347412109375)
('numpy', 8, 'pyscf', 0.09161663055419922, 0.09203903675079346, 1.9323973655700684)
('numpy', 12, 'civector', 0.01762080192565918, 0.008578836917877197, 0.18919754028320312)
('numpy', 12, 'civector-large', 0.02767157554626465, 0.02590758800506592, 0.545823335647583)
('numpy', 12, 'pyscf', 0.4722292423248291, 0.443195104598999, 9.33613133430481)
('cupy', 4, 'civector', 0.32034993171691895, 0.008597338199615478, 0.4922966957092285)
('cupy', 4, 'civector-large', 0.02430891990661621, 0.014339327812194824, 0.3110954761505127)
('cupy', 4, 'pyscf', 0.01539301872253418, 0.01657602787017822, 0.34691357612609863)
('cupy', 8, 'civector', 0.036718130111694336, 0.01722104549407959, 0.38113903999328613)
('cupy', 8, 'civector-large', 0.05296659469604492, 0.052764499187469484, 1.1082565784454346)
('cupy', 8, 'pyscf', 0.09645533561706543, 0.09423872232437133, 1.9812297821044922)
('cupy', 12, 'civector', 0.1256849765777588, 0.04901479482650757, 1.1059808731079102)
('cupy', 12, 'civector-large', 0.18027353286743164, 0.1753893733024597, 3.688060998916626)
('cupy', 12, 'pyscf', 0.4615364074707031, 0.45049295425415037, 9.471395492553711)

Results and Discussion

[6]:
df = pd.DataFrame(
    time_data, columns=["backend", "qubits", "engine", "staging time", "run time", "total time"]
).set_index(["backend", "qubits", "engine"])
df
[6]:
staging time run time total time
backend qubits engine
jax 4 tensornetwork 0.673489 0.001301 0.699509
statevector 0.553246 0.000826 0.569758
civector 0.609218 0.006143 0.732082
civector-large 0.666502 0.006435 0.795211
pyscf 0.029404 0.016277 0.354943
8 tensornetwork 4.530416 0.003479 4.599999
statevector 2.340441 0.001799 2.376428
civector 0.654748 0.007375 0.802249
civector-large 1.635211 0.006232 1.759855
pyscf 0.107285 0.094037 1.988019
12 tensornetwork 24.586898 0.058174 25.750372
statevector 7.562686 0.005127 7.665227
civector 0.904868 0.012383 1.152533
civector-large 8.323725 0.008331 8.490335
pyscf 0.485607 0.443795 9.361508
numpy 4 civector 0.003418 0.003229 0.068000
civector-large 0.003772 0.003748 0.078730
pyscf 0.014780 0.014782 0.310412
8 civector 0.005615 0.003976 0.085136
civector-large 0.006962 0.006996 0.146873
pyscf 0.091617 0.092039 1.932397
12 civector 0.017621 0.008579 0.189198
civector-large 0.027672 0.025908 0.545823
pyscf 0.472229 0.443195 9.336131
cupy 4 civector 0.320350 0.008597 0.492297
civector-large 0.024309 0.014339 0.311095
pyscf 0.015393 0.016576 0.346914
8 civector 0.036718 0.017221 0.381139
civector-large 0.052967 0.052764 1.108257
pyscf 0.096455 0.094239 1.981230
12 civector 0.125685 0.049015 1.105981
civector-large 0.180274 0.175389 3.688061
pyscf 0.461536 0.450493 9.471395

The table contains rich information (all times are in seconds), but conclusions are not easily drawn from it directly.

Next, we find the best option for each system size.

[7]:
interesting_indices = []
for i, ddf in df.groupby("qubits"):
    run_time_idx = ddf["run time"].idxmin()
    total_time_idx = ddf["total time"].idxmin()
    print(run_time_idx, total_time_idx)
    interesting_indices.extend([run_time_idx, total_time_idx])
('jax', 4, 'statevector') ('numpy', 4, 'civector')
('jax', 8, 'statevector') ('numpy', 8, 'civector')
('jax', 12, 'statevector') ('numpy', 12, 'civector')
[8]:
df.loc[interesting_indices]
[8]:
staging time run time total time
backend qubits engine
jax 4 statevector 0.553246 0.000826 0.569758
numpy 4 civector 0.003418 0.003229 0.068000
jax 8 statevector 2.340441 0.001799 2.376428
numpy 8 civector 0.005615 0.003976 0.085136
jax 12 statevector 7.562686 0.005127 7.665227
numpy 12 civector 0.017621 0.008579 0.189198

For every system size tested, JAX + statevector gives the fastest run time per evaluation.

However, once the staging time is taken into account, NumPy + civector is the most efficient, since JAX's one-time compilation dominates the total time for these small systems.
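
As a minimal sketch of how these numbers translate into an actual optimization (the use of `scipy.optimize.minimize` here is an illustration, not taken from the benchmark above), the NumPy backend and the `civector` engine can be combined as follows:

[ ]:
from scipy.optimize import minimize

set_backend("numpy")
uccsd_np = UCCSD(h_chain(4))

# energy_and_grad returns (energy, gradient), so jac=True lets SciPy consume both
res = minimize(
    lambda params: uccsd_np.energy_and_grad(params, engine="civector"),
    x0=np.zeros(uccsd_np.n_params),
    jac=True,
    method="L-BFGS-B",
)
print(res.fun)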

Note that this conclusion is only valid for system sizes up to 16 qubits. For larger systems, CuPy + civector-large is the most scalable choice.
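
As a sketch only (the 8-atom chain, i.e. 16 qubits, is an illustrative size not benchmarked above, and a CUDA-capable GPU with sufficient memory is assumed), a larger system would be evaluated like this:

[ ]:
# sketch: evaluate a larger hydrogen chain on the GPU with the civector-large engine
set_backend("cupy")
uccsd_large = UCCSD(h_chain(8))
params_large = np.random.rand(uccsd_large.n_params)
e, grad = uccsd_large.energy_and_grad(params_large, engine="civector-large")
print(e)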
