Skip to content

Commit

Permalink
Truncate save_expval (Qiskit#2216)
Browse files Browse the repository at this point in the history
* truncate save_expval

* fix truncation

* fix truncation

* add num_original_qubits to aer_circuit to get num_qubits without ancilla qubits

* Fix adding qubitset

* add test case, release note and fix docs

* fix doc

* fix doc

* fix doc

* fix doc

* fix doc

* fix doc

* fix doc

* fix doc

* change truncation strategy

* format

* remove print

* no truncation when circuit is empty

* Update VERSION.txt

revert to 0.15.0

---------

Co-authored-by: Hiroshi Horii <[email protected]>
  • Loading branch information
doichanj and hhorii committed Sep 12, 2024
1 parent e384a36 commit 8dee229
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 5 deletions.
8 changes: 8 additions & 0 deletions releasenotes/notes/truncate_expval-7fa814c732cca8db.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
fixes:
- |
This fix truncates qubits of save_expval operation when EstimatorV2
is made from existing backends using `from_backend`.
By transpiling on the existing backends, ancilla qubits are filled
to all the qubits that causes memory error on the simulator.
So Aer removes unused qubits for save_expval operation.
55 changes: 50 additions & 5 deletions src/framework/circuit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Circuit {
uint_t num_qubits = 0; // maximum number of qubits needed for ops
uint_t num_memory = 0; // maximum number of memory clbits needed for ops
uint_t num_registers = 0; // maximum number of registers clbits needed for ops
uint_t num_original_qubits = 0; // number of qubits without ancilla qubits

// Measurement params
bool has_conditional = false; // True if any ops are conditional
Expand Down Expand Up @@ -419,7 +420,20 @@ void Circuit::reset_metadata() {
void Circuit::add_op_metadata(const Op &op) {
has_conditional |= op.conditional;
opset_.insert(op);
qubitset_.insert(op.qubits.begin(), op.qubits.end());
if (!qubitset_.empty() &&
(op.type == OpType::save_expval || op.type == OpType::save_expval_var)) {
for (int_t j = 0; j < op.expval_params.size(); j++) {
const std::string &pauli = std::get<0>(op.expval_params[j]);
for (int_t i = 0; i < op.qubits.size(); i++) {
// add qubit with non-I operator
if (pauli[pauli.size() - 1 - i] != 'I') {
qubitset_.insert(op.qubits[i]);
}
}
}
} else {
qubitset_.insert(op.qubits.begin(), op.qubits.end());
}
memoryset_.insert(op.memory.begin(), op.memory.end());
registerset_.insert(op.registers.begin(), op.registers.end());

Expand Down Expand Up @@ -589,6 +603,19 @@ void Circuit::set_params(bool truncation) {
}
if (remapped_qubits) {
remap_qubits(ops[pos]);
} else if (truncation && qubitmap_.size() < ops[pos].qubits.size()) {
// truncate save_expval here when remap is not needed
if (ops[pos].type == OpType::save_expval ||
ops[pos].type == OpType::save_expval_var) {
int_t nparams = ops[pos].expval_params.size();
for (int_t i = 0; i < nparams; i++) {
std::string &pauli = std::get<0>(ops[pos].expval_params[i]);
std::string new_pauli;
new_pauli.assign(pauli.end() - qubitmap_.size(), pauli.end());
pauli = new_pauli;
}
ops[pos].qubits.resize(qubitmap_.size());
}
}
if (pos != op_idx) {
ops[op_idx] = std::move(ops[pos]);
Expand Down Expand Up @@ -653,11 +680,29 @@ void Circuit::set_params(bool truncation) {
}

void Circuit::remap_qubits(Op &op) const {
reg_t new_qubits;
for (auto &qubit : op.qubits) {
new_qubits.push_back(qubitmap_.at(qubit));
// truncate save_expval
if (op.type == OpType::save_expval || op.type == OpType::save_expval_var) {
int_t nparams = op.expval_params.size();
for (int_t i = 0; i < nparams; i++) {
std::string &pauli = std::get<0>(op.expval_params[i]);
std::string new_pauli;
new_pauli.resize(qubitmap_.size());
for (auto q = qubitmap_.cbegin(); q != qubitmap_.cend(); q++) {
new_pauli[qubitmap_.size() - 1 - q->second] =
pauli[pauli.size() - 1 - q->first];
}
pauli = new_pauli;
}
for (int_t i = 0; i < qubitmap_.size(); i++) {
op.qubits[i] = i;
}
} else {
reg_t new_qubits;
for (auto &qubit : op.qubits) {
new_qubits.push_back(qubitmap_.at(qubit));
}
op.qubits = std::move(new_qubits);
}
op.qubits = std::move(new_qubits);
}

bool Circuit::check_result_ancestor(
Expand Down
38 changes: 38 additions & 0 deletions test/terra/primitives/test_estimator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from test.terra.common import QiskitAerTestCase

import numpy as np
from qiskit import transpile
from qiskit.circuit import Parameter, QuantumCircuit
from qiskit.circuit.library import RealAmplitudes
from qiskit.primitives import StatevectorEstimator
Expand All @@ -26,6 +27,7 @@
from qiskit.primitives.containers.observables_array import ObservablesArray
from qiskit.quantum_info import SparsePauliOp
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
from qiskit.providers.fake_provider import GenericBackendV2

from qiskit_aer import AerSimulator
from qiskit_aer.primitives import EstimatorV2
Expand Down Expand Up @@ -407,6 +409,42 @@ def test_metadata(self):
{"target_precision": 0.1, "circuit_metadata": qc2.metadata},
)

def test_truncate(self):
"""Test for truncation of save_expval"""
qc = QuantumCircuit(2, 2)
qc.h(0)
qc.cx(0, 1)
qc.append(RealAmplitudes(num_qubits=2, reps=2), [0, 1])
backend_2 = GenericBackendV2(num_qubits=2)
backend_5 = GenericBackendV2(num_qubits=5)

qc_2 = transpile(qc, backend_2, optimization_level=0)
qc_5 = transpile(qc, backend_5, optimization_level=0)

estimator_2 = EstimatorV2.from_backend(backend_2, options=self._options)
estimator_5 = EstimatorV2.from_backend(backend_5, options=self._options)

H1 = self.observable
H1_2 = H1.apply_layout(qc_2.layout)
H1_5 = H1.apply_layout(qc_5.layout)
theta1 = [0, 1, 1, 2, 3, 5]

result_2 = estimator_2.run(
[
(qc_2, [H1_2], [theta1]),
],
precision=0.01,
).result()
result_5 = estimator_5.run(
[
(qc_5, [H1_5], [theta1]),
],
precision=0.01,
).result()
self.assertAlmostEqual(
result_5[0].data["evs"][0], result_2[0].data["evs"][0], delta=self._rtol
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8dee229

Please sign in to comment.