# Benchmark engines

## Overview
In this notebook, we benchmark the run time of all engines in TenCirChem, along with their interplay with different backends.
A hydrogen chain is used as the benchmark system, with sizes ranging from 2 to 6 atoms.
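For reference, `h_chain(n)` constructs `n` hydrogen atoms placed evenly along a line. A minimal sketch of such a geometry generator is shown below; the `h_chain_geometry` name and the 0.8 Å spacing are illustrative assumptions, not TenCirChem's actual implementation, which returns a molecule object rather than a raw atom list.

```python
def h_chain_geometry(n_h, bond_distance=0.8):
    # Hypothetical sketch: n_h hydrogen atoms evenly spaced along the z axis.
    # The real tencirchem.molecule.h_chain builds a molecule object instead.
    return [("H", (0.0, 0.0, i * bond_distance)) for i in range(n_h)]

print(h_chain_geometry(3))
```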
## Setup
[1]:
import time
import numpy as np
import pandas as pd
from tencirchem import UCCSD, set_backend
from tencirchem.molecule import h_chain
[2]:
n_h_list = list(range(2, 8, 2))
uccsd_list = [UCCSD(h_chain(n_h)) for n_h in n_h_list]
params_list = [np.random.rand(uccsd.n_params) for uccsd in uccsd_list]
[3]:
# the tensornetwork and statevector engines are only compatible with the JAX backend
jax_engines = ["tensornetwork", "statevector", "civector", "civector-large", "pyscf"]
numpy_engines = ["civector", "civector-large", "pyscf"]
cupy_engines = numpy_engines
tested_engines_list = [jax_engines, numpy_engines, cupy_engines]
## Benchmark
[5]:
time_data = []
for backend, tested_engines in zip(["jax", "numpy", "cupy"], tested_engines_list):
    set_backend(backend)
    for uccsd, params in zip(uccsd_list, params_list):
        for engine in tested_engines:
            # dry run first, to let the engine compile and build its caches
            time1 = time.time()
            uccsd.energy_and_grad(params, engine=engine)
            time2 = time.time()
            staging_time = time2 - time1
            # several real runs, assuming `n_run` evaluations during the optimization
            n_run = 20
            for i in range(n_run):
                uccsd.energy_and_grad(params, engine=engine)
            run_time = (time.time() - time2) / n_run
            item = (backend, uccsd.n_qubits, engine, staging_time, run_time, staging_time + n_run * run_time)
            print(item)
            time_data.append(item)
('jax', 4, 'tensornetwork', 0.6734888553619385, 0.0013010025024414063, 0.6995089054107666)
('jax', 4, 'statevector', 0.5532455444335938, 0.0008256316184997558, 0.5697581768035889)
('jax', 4, 'civector', 0.6092183589935303, 0.0061431884765625, 0.7320821285247803)
('jax', 4, 'civector-large', 0.6665019989013672, 0.0064354419708251955, 0.7952108383178711)
('jax', 4, 'pyscf', 0.0294039249420166, 0.016276955604553223, 0.35494303703308105)
('jax', 8, 'tensornetwork', 4.530415773391724, 0.003479158878326416, 4.599998950958252)
('jax', 8, 'statevector', 2.3404414653778076, 0.001799321174621582, 2.3764278888702393)
('jax', 8, 'civector', 0.6547484397888184, 0.007375049591064453, 0.8022494316101074)
('jax', 8, 'civector-large', 1.63521146774292, 0.006232154369354248, 1.7598545551300049)
('jax', 8, 'pyscf', 0.1072854995727539, 0.09403668642044068, 1.9880192279815674)
('jax', 12, 'tensornetwork', 24.58689785003662, 0.05817370414733887, 25.7503719329834)
('jax', 12, 'statevector', 7.56268572807312, 0.005127060413360596, 7.665226936340332)
('jax', 12, 'civector', 0.9048683643341064, 0.01238323450088501, 1.1525330543518066)
('jax', 12, 'civector-large', 8.32372498512268, 0.00833052396774292, 8.490335464477539)
('jax', 12, 'pyscf', 0.4856069087982178, 0.4437950611114502, 9.361508131027222)
('numpy', 4, 'civector', 0.0034182071685791016, 0.003229069709777832, 0.06799960136413574)
('numpy', 4, 'civector-large', 0.0037720203399658203, 0.003747880458831787, 0.07872962951660156)
('numpy', 4, 'pyscf', 0.01477956771850586, 0.014781630039215088, 0.3104121685028076)
('numpy', 8, 'civector', 0.005614757537841797, 0.003976082801818848, 0.08513641357421875)
('numpy', 8, 'civector-large', 0.0069615840911865234, 0.006995594501495362, 0.14687347412109375)
('numpy', 8, 'pyscf', 0.09161663055419922, 0.09203903675079346, 1.9323973655700684)
('numpy', 12, 'civector', 0.01762080192565918, 0.008578836917877197, 0.18919754028320312)
('numpy', 12, 'civector-large', 0.02767157554626465, 0.02590758800506592, 0.545823335647583)
('numpy', 12, 'pyscf', 0.4722292423248291, 0.443195104598999, 9.33613133430481)
('cupy', 4, 'civector', 0.32034993171691895, 0.008597338199615478, 0.4922966957092285)
('cupy', 4, 'civector-large', 0.02430891990661621, 0.014339327812194824, 0.3110954761505127)
('cupy', 4, 'pyscf', 0.01539301872253418, 0.01657602787017822, 0.34691357612609863)
('cupy', 8, 'civector', 0.036718130111694336, 0.01722104549407959, 0.38113903999328613)
('cupy', 8, 'civector-large', 0.05296659469604492, 0.052764499187469484, 1.1082565784454346)
('cupy', 8, 'pyscf', 0.09645533561706543, 0.09423872232437133, 1.9812297821044922)
('cupy', 12, 'civector', 0.1256849765777588, 0.04901479482650757, 1.1059808731079102)
('cupy', 12, 'civector-large', 0.18027353286743164, 0.1753893733024597, 3.688060998916626)
('cupy', 12, 'pyscf', 0.4615364074707031, 0.45049295425415037, 9.471395492553711)
## Results and Discussion
[6]:
df = pd.DataFrame(
    time_data, columns=["backend", "qubits", "engine", "staging time", "run time", "total time"]
).set_index(["backend", "qubits", "engine"])
df
[6]:
| backend | qubits | engine | staging time (s) | run time (s) | total time (s) |
|---|---|---|---|---|---|
| jax | 4 | tensornetwork | 0.673489 | 0.001301 | 0.699509 |
| jax | 4 | statevector | 0.553246 | 0.000826 | 0.569758 |
| jax | 4 | civector | 0.609218 | 0.006143 | 0.732082 |
| jax | 4 | civector-large | 0.666502 | 0.006435 | 0.795211 |
| jax | 4 | pyscf | 0.029404 | 0.016277 | 0.354943 |
| jax | 8 | tensornetwork | 4.530416 | 0.003479 | 4.599999 |
| jax | 8 | statevector | 2.340441 | 0.001799 | 2.376428 |
| jax | 8 | civector | 0.654748 | 0.007375 | 0.802249 |
| jax | 8 | civector-large | 1.635211 | 0.006232 | 1.759855 |
| jax | 8 | pyscf | 0.107285 | 0.094037 | 1.988019 |
| jax | 12 | tensornetwork | 24.586898 | 0.058174 | 25.750372 |
| jax | 12 | statevector | 7.562686 | 0.005127 | 7.665227 |
| jax | 12 | civector | 0.904868 | 0.012383 | 1.152533 |
| jax | 12 | civector-large | 8.323725 | 0.008331 | 8.490335 |
| jax | 12 | pyscf | 0.485607 | 0.443795 | 9.361508 |
| numpy | 4 | civector | 0.003418 | 0.003229 | 0.068000 |
| numpy | 4 | civector-large | 0.003772 | 0.003748 | 0.078730 |
| numpy | 4 | pyscf | 0.014780 | 0.014782 | 0.310412 |
| numpy | 8 | civector | 0.005615 | 0.003976 | 0.085136 |
| numpy | 8 | civector-large | 0.006962 | 0.006996 | 0.146873 |
| numpy | 8 | pyscf | 0.091617 | 0.092039 | 1.932397 |
| numpy | 12 | civector | 0.017621 | 0.008579 | 0.189198 |
| numpy | 12 | civector-large | 0.027672 | 0.025908 | 0.545823 |
| numpy | 12 | pyscf | 0.472229 | 0.443195 | 9.336131 |
| cupy | 4 | civector | 0.320350 | 0.008597 | 0.492297 |
| cupy | 4 | civector-large | 0.024309 | 0.014339 | 0.311095 |
| cupy | 4 | pyscf | 0.015393 | 0.016576 | 0.346914 |
| cupy | 8 | civector | 0.036718 | 0.017221 | 0.381139 |
| cupy | 8 | civector-large | 0.052967 | 0.052764 | 1.108257 |
| cupy | 8 | pyscf | 0.096455 | 0.094239 | 1.981230 |
| cupy | 12 | civector | 0.125685 | 0.049015 | 1.105981 |
| cupy | 12 | civector-large | 0.180274 | 0.175389 | 3.688061 |
| cupy | 12 | pyscf | 0.461536 | 0.450493 | 9.471395 |
The table contains rich information, but a conclusion is not easily drawn from it directly.
Next, we find the best option for each system size.
[7]:
interesting_indices = []
for i, ddf in df.groupby("qubits"):
    run_time_idx = ddf["run time"].idxmin()
    total_time_idx = ddf["total time"].idxmin()
    print(run_time_idx, total_time_idx)
    interesting_indices.extend([run_time_idx, total_time_idx])
('jax', 4, 'statevector') ('numpy', 4, 'civector')
('jax', 8, 'statevector') ('numpy', 8, 'civector')
('jax', 12, 'statevector') ('numpy', 12, 'civector')
[8]:
df.loc[interesting_indices]
[8]:
| backend | qubits | engine | staging time (s) | run time (s) | total time (s) |
|---|---|---|---|---|---|
| jax | 4 | statevector | 0.553246 | 0.000826 | 0.569758 |
| numpy | 4 | civector | 0.003418 | 0.003229 | 0.068000 |
| jax | 8 | statevector | 2.340441 | 0.001799 | 2.376428 |
| numpy | 8 | civector | 0.005615 | 0.003976 | 0.085136 |
| jax | 12 | statevector | 7.562686 | 0.005127 | 7.665227 |
| numpy | 12 | civector | 0.017621 | 0.008579 | 0.189198 |
For every system size tested, JAX + statevector is the fastest in terms of run time alone.
However, once the staging time is taken into account, NumPy + civector is the most efficient overall.
We note that this conclusion is only valid for system sizes up to roughly 16 qubits. For larger systems, CuPy + civector-large is the most scalable choice.
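Based on these observations, engine selection could be wrapped in a simple dispatch heuristic. The sketch below is illustrative, not a TenCirChem API: the `choose_engine` helper and the 16-qubit threshold simply encode the rule of thumb stated above, and the returned pair is meant to be fed to `set_backend` and to the `engine` argument of `energy_and_grad`.

```python
def choose_engine(n_qubits):
    # Hypothetical helper encoding the benchmark's rule of thumb:
    # up to ~16 qubits, NumPy + civector minimizes total (staging + run) time;
    # beyond that, CuPy + civector-large scales best.
    if n_qubits <= 16:
        return "numpy", "civector"
    return "cupy", "civector-large"

print(choose_engine(12), choose_engine(24))
```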