Skip to content

Commit

Permalink
[WIP] tempretize python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
KowerKoint authored and KowerKoint committed Nov 20, 2024
1 parent 1641b90 commit 239438d
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 96 deletions.
25 changes: 19 additions & 6 deletions include/scaluq/gate/gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ namespace internal {
"Get string representation of the gate.")

nb::class_<Gate<double>> gate_base_def_double;
nb::class_<Gate<float>> gate_base_def_float;

#define DEF_GATE(GATE_TYPE, FLOAT, DESCRIPTION) \
::scaluq::internal::gate_base_def_##FLOAT.def(nb::init<GATE_TYPE<FLOAT>>(), \
Expand All @@ -354,6 +355,7 @@ nb::class_<Gate<double>> gate_base_def_double;
"\n\n.. note:: Upcast is required to use gate-general functions (ex: add to Circuit).") \
.def(nb::init<Gate<FLOAT>>())

template <std::floating_point Fp>
void bind_gate_gate_hpp(nb::module_& m) {
nb::enum_<GateType>(m, "GateType", "Enum of Gate Type.")
.value("I", GateType::I)
Expand Down Expand Up @@ -384,12 +386,23 @@ void bind_gate_gate_hpp(nb::module_& m) {
.value("Pauli", GateType::Pauli)
.value("PauliRotation", GateType::PauliRotation);

gate_base_def_double =
DEF_GATE_BASE(Gate,
double,
"General class of QuantumGate.\n\n.. note:: Downcast to requred to use "
"gate-specific functions.")
.def(nb::init<Gate<double>>(), "Just copy shallowly.");
if constexpr (std::is_same_v<Fp, double>) {
gate_base_def_double =
DEF_GATE_BASE(Gate,
double,
"General class of QuantumGate.\n\n.. note:: Downcast to requred to use "
"gate-specific functions.")
.def(nb::init<Gate<double>>(), "Just copy shallowly.");
} else if constexpr (std::is_same_v<Fp, float>) {
gate_base_def_float =
DEF_GATE_BASE(Gate,
float,
"General class of QuantumGate.\n\n.. note:: Downcast to requred to use "
"gate-specific functions.")
.def(nb::init<Gate<float>>(), "Just copy shallowly.");
} else {
static_assert(internal::lazy_false_v<void>);
}
}
} // namespace internal
#endif
Expand Down
63 changes: 31 additions & 32 deletions include/scaluq/gate/gate_standard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,48 +466,49 @@ using SwapGate = internal::GatePtr<internal::SwapGateImpl<Fp>>;

#ifdef SCALUQ_USE_NANOBIND
namespace internal {
template <std::floating_point Fp>
void bind_gate_gate_standard_hpp(nb::module_& m) {
DEF_GATE(IGate, double, "Specific class of Pauli-I gate.");
DEF_GATE(IGate, Fp, "Specific class of Pauli-I gate.");
DEF_GATE(GlobalPhaseGate,
double,
Fp,
"Specific class of gate, which rotate global phase, represented as "
"$e^{i\\mathrm{phase}}I$.")
.def(
"phase",
[](const GlobalPhaseGate<double>& gate) { return gate->phase(); },
[](const GlobalPhaseGate<Fp>& gate) { return gate->phase(); },
"Get `phase` property");
DEF_GATE(XGate, double, "Specific class of Pauli-X gate.");
DEF_GATE(YGate, double, "Specific class of Pauli-Y gate.");
DEF_GATE(ZGate, double, "Specific class of Pauli-Z gate.");
DEF_GATE(HGate, double, "Specific class of Hadamard gate.");
DEF_GATE(XGate, Fp, "Specific class of Pauli-X gate.");
DEF_GATE(YGate, Fp, "Specific class of Pauli-Y gate.");
DEF_GATE(ZGate, Fp, "Specific class of Pauli-Z gate.");
DEF_GATE(HGate, Fp, "Specific class of Hadamard gate.");
DEF_GATE(SGate,
double,
Fp,
"Specific class of S gate, represented as $\\begin { bmatrix }\n1 & 0\\\\\n0 &"
"i\n\\end{bmatrix}$.");
DEF_GATE(SdagGate, double, "Specific class of inverse of S gate.");
DEF_GATE(SdagGate, Fp, "Specific class of inverse of S gate.");
DEF_GATE(TGate,
double,
Fp,
"Specific class of T gate, represented as $\\begin { bmatrix }\n1 & 0\\\\\n0 &"
"e^{i\\pi/4}\n\\end{bmatrix}$.");
DEF_GATE(TdagGate, double, "Specific class of inverse of T gate.");
DEF_GATE(TdagGate, Fp, "Specific class of inverse of T gate.");
DEF_GATE(
SqrtXGate,
double,
Fp,
"Specific class of sqrt(X) gate, represented as $\\begin{ bmatrix }\n1+i & 1-i\\\\\n1-i "
"& 1+i\n\\end{bmatrix}$.");
DEF_GATE(SqrtXdagGate, double, "Specific class of inverse of sqrt(X) gate.");
DEF_GATE(SqrtXdagGate, Fp, "Specific class of inverse of sqrt(X) gate.");
DEF_GATE(SqrtYGate,
double,
Fp,
"Specific class of sqrt(Y) gate, represented as $\\begin{ bmatrix }\n1+i & -1-i "
"\\\\\n1+i & 1+i\n\\end{bmatrix}$.");
DEF_GATE(SqrtYdagGate, double, "Specific class of inverse of sqrt(Y) gate.");
DEF_GATE(SqrtYdagGate, Fp, "Specific class of inverse of sqrt(Y) gate.");
DEF_GATE(
P0Gate,
double,
Fp,
"Specific class of projection gate to $\\ket{0}$.\n\n.. note:: This gate is not unitary.");
DEF_GATE(
P1Gate,
double,
Fp,
"Specific class of projection gate to $\\ket{1}$.\n\n.. note:: This gate is not unitary.");

#define DEF_ROTATION_GATE(GATE_TYPE, FLOAT, DESCRIPTION) \
Expand All @@ -519,57 +520,55 @@ void bind_gate_gate_standard_hpp(nb::module_& m) {

DEF_ROTATION_GATE(
RXGate,
double,
Fp,
"Specific class of X rotation gate, represented as $e^{-i\\frac{\\mathrm{angle}}{2}X}$.");
DEF_ROTATION_GATE(
RYGate,
double,
Fp,
"Specific class of Y rotation gate, represented as $e^{-i\\frac{\\mathrm{angle}}{2}Y}$.");
DEF_ROTATION_GATE(
RZGate,
double,
Fp,
"Specific class of Z rotation gate, represented as $e^{-i\\frac{\\mathrm{angle}}{2}Z}$.");

DEF_GATE(U1Gate,
double,
Fp,
"Specific class of IBMQ's U1 Gate, which is a rotation abount Z-axis, "
"represented as "
"$\\begin{bmatrix}\n1 & 0\\\\\n0 & e^{i\\lambda}\n\\end{bmatrix}$.")
.def(
"lambda_",
[](const U1Gate<double>& gate) { return gate->lambda(); },
[](const U1Gate<Fp>& gate) { return gate->lambda(); },
"Get `lambda` property.");
DEF_GATE(U2Gate,
double,
Fp,
"Specific class of IBMQ's U2 Gate, which is a rotation about X+Z-axis, "
"represented as "
"$\\frac{1}{\\sqrt{2}} \\begin{bmatrix}1 & -e^{-i\\lambda}\\\\\n"
"e^{i\\phi} & e^{i(\\phi+\\lambda)}\n\\end{bmatrix}$.")
.def(
"phi", [](const U2Gate<double>& gate) { return gate->phi(); }, "Get `phi` property.")
"phi", [](const U2Gate<Fp>& gate) { return gate->phi(); }, "Get `phi` property.")
.def(
"lambda_",
[](const U2Gate<double>& gate) { return gate->lambda(); },
[](const U2Gate<Fp>& gate) { return gate->lambda(); },
"Get `lambda` property.");
DEF_GATE(U3Gate,
double,
Fp,
"Specific class of IBMQ's U3 Gate, which is a rotation abount 3 axis, "
"represented as "
"$\\begin{bmatrix}\n\\cos \\frac{\\theta}{2} & "
"-e^{i\\lambda}\\sin\\frac{\\theta}{2}\\\\\n"
"e^{i\\phi}\\sin\\frac{\\theta}{2} & "
"e^{i(\\phi+\\lambda)}\\cos\\frac{\\theta}{2}\n\\end{bmatrix}$.")
.def(
"theta",
[](const U3Gate<double>& gate) { return gate->theta(); },
"Get `theta` property.")
"theta", [](const U3Gate<Fp>& gate) { return gate->theta(); }, "Get `theta` property.")
.def(
"phi", [](const U3Gate<double>& gate) { return gate->phi(); }, "Get `phi` property.")
"phi", [](const U3Gate<Fp>& gate) { return gate->phi(); }, "Get `phi` property.")
.def(
"lambda_",
[](const U3Gate<double>& gate) { return gate->lambda(); },
[](const U3Gate<Fp>& gate) { return gate->lambda(); },
"Get `lambda` property.");
DEF_GATE(SwapGate, double, "Specific class of two-qubit swap gate.");
DEF_GATE(SwapGate, Fp, "Specific class of two-qubit swap gate.");
}
} // namespace internal
#endif
Expand Down
28 changes: 22 additions & 6 deletions include/scaluq/gate/param_gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ using ParamGate = internal::ParamGatePtr<internal::ParamGateBase<Fp>>;
#ifdef SCALUQ_USE_NANOBIND
namespace internal {
#define DEF_PARAM_GATE_BASE(PARAM_GATE_TYPE, FLOAT, DESCRIPTION) \
nb::class_<PARAM_GATE_TYPE<FLOAT>>(m, #PARAM_GATE_TYPE "_" #FLOAT, DESCRIPTION) \
nb::class_<PARAM_GATE_TYPE<FLOAT>>(m, #PARAM_GATE_TYPE, DESCRIPTION) \
.def("param_gate_type", \
&PARAM_GATE_TYPE<FLOAT>::param_gate_type, \
"Get parametric gate type as `ParamGateType` enum.") \
Expand Down Expand Up @@ -219,6 +219,7 @@ namespace internal {
"Get matrix representation of the gate with holding the parameter.")

nb::class_<ParamGate<double>> param_gate_base_def_double;
nb::class_<ParamGate<float>> param_gate_base_def_float;

#define DEF_PARAM_GATE(PARAM_GATE_TYPE, FLOAT, DESCRIPTION) \
::scaluq::internal::param_gate_base_def_##FLOAT.def(nb::init<PARAM_GATE_TYPE<FLOAT>>(), \
Expand All @@ -230,18 +231,33 @@ nb::class_<ParamGate<double>> param_gate_base_def_double;
"\n\n.. note:: Upcast is required to use gate-general functions (ex: add to Circuit).") \
.def(nb::init<ParamGate<FLOAT>>())

template <std::floating_point Fp>
void bind_gate_param_gate_hpp(nb::module_& m) {
nb::enum_<ParamGateType>(m, "ParamGateType", "Enum of ParamGate Type.")
.value("ParamRX", ParamGateType::ParamRX)
.value("ParamRY", ParamGateType::ParamRY)
.value("ParamRZ", ParamGateType::ParamRZ)
.value("ParamPauliRotation", ParamGateType::ParamPauliRotation);

param_gate_base_def_double = DEF_PARAM_GATE_BASE(
ParamGate,
double,
"General class of parametric quantum gate.\n\n.. note:: Downcast to requred to use "
"gate-specific functions.");
if constexpr (std::is_same_v<Fp, double>) {
param_gate_base_def_double =
DEF_PARAM_GATE_BASE(
ParamGate,
double,
"General class of parametric quantum gate.\n\n.. note:: Downcast to requred to use "
"gate-specific functions.")
.def(nb::init<ParamGate<double>>(), "Just copy shallowly.");
} else if constexpr (std::is_same_v<Fp, float>) {
param_gate_base_def_float =
DEF_PARAM_GATE_BASE(
ParamGate,
float,
"General class of parametric quantum gate.\n\n.. note:: Downcast to requred to use "
"gate-specific functions.")
.def(nb::init<ParamGate<float>>(), "Just copy shallowly.");
} else {
static_asert(internal::lazy_false_v<void>);
}
}

} // namespace internal
Expand Down
50 changes: 25 additions & 25 deletions include/scaluq/state/state_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ class StateVector {

#ifdef SCALUQ_USE_NANOBIND
namespace internal {
template <std::floating_point Fp>
void bind_state_state_vector_hpp(nb::module_& m) {
nb::class_<StateVector<double>>(
nb::class_<StateVector<Fp>>(
m,
"StateVector",
"Vector representation of quantum state.\n\n.. note:: Qubit index is "
Expand All @@ -98,13 +99,13 @@ void bind_state_state_vector_hpp(nb::module_& m) {
.def(nb::init<std::uint64_t>(),
"Construct state vector with specified qubits, initialized with computational "
"basis $\\ket{0\\dots0}$.")
.def(nb::init<const StateVector<double>&>(),
.def(nb::init<const StateVector<Fp>&>(),
"Constructing state vector by copying other state.")
.def_static(
"Haar_random_state",
[](std::uint64_t n_qubits, std::optional<std::uint64_t> seed) {
return StateVector<double>::Haar_random_state(
n_qubits, seed.value_or(std::random_device{}()));
return StateVector<Fp>::Haar_random_state(n_qubits,
seed.value_or(std::random_device{}()));
},
"n_qubits"_a,
"seed"_a = std::nullopt,
Expand Down Expand Up @@ -143,70 +144,69 @@ void bind_state_state_vector_hpp(nb::module_& m) {
.build_as_google_style()
.c_str())
.def("set_amplitude_at",
&StateVector<double>::set_amplitude_at,
&StateVector<Fp>::set_amplitude_at,
"Manually set amplitude at one index.")
.def("get_amplitude_at",
&StateVector<double>::get_amplitude_at,
&StateVector<Fp>::get_amplitude_at,
"Get amplitude at one index.\n\n.. note:: If you want to get all amplitudes, you "
"should "
"use `StateVector::get_amplitudes()`.")
.def("set_zero_state",
&StateVector<double>::set_zero_state,
&StateVector<Fp>::set_zero_state,
"Initialize with computational basis $\\ket{00\\dots0}$.")
.def("set_zero_norm_state",
&StateVector<double>::set_zero_norm_state,
&StateVector<Fp>::set_zero_norm_state,
"Initialize with 0 (null vector).")
.def("set_computational_basis",
&StateVector<double>::set_computational_basis,
&StateVector<Fp>::set_computational_basis,
"Initialize with computational basis \\ket{\\mathrm{basis}}.")
.def("get_amplitudes",
&StateVector<double>::get_amplitudes,
&StateVector<Fp>::get_amplitudes,
"Get all amplitudes with as `list[complex]`.")
.def("n_qubits", &StateVector<double>::n_qubits, "Get num of qubits.")
.def("n_qubits", &StateVector<Fp>::n_qubits, "Get num of qubits.")
.def("dim",
&StateVector<double>::dim,
&StateVector<Fp>::dim,
"Get dimension of the vector ($=2^\\mathrm{n\\_qubits}$).")
.def("get_squared_norm",
&StateVector<double>::get_squared_norm,
&StateVector<Fp>::get_squared_norm,
"Get squared norm of the state. $\\braket{\\psi|\\psi}$.")
.def("normalize",
&StateVector<double>::normalize,
&StateVector<Fp>::normalize,
"Normalize state (let $\\braket{\\psi|\\psi} = 1$ by multiplying coef).")
.def("get_zero_probability",
&StateVector<double>::get_zero_probability,
&StateVector<Fp>::get_zero_probability,
"Get the probability to observe $\\ket{0}$ at specified index.")
.def("get_marginal_probability",
&StateVector<double>::get_marginal_probability,
&StateVector<Fp>::get_marginal_probability,
"Get the marginal probability to observe as specified. Specify the result as n-length "
"list. `0` and `1` represent the qubit is observed and get the value. `2` represents "
"the qubit is not observed.")
.def("get_entropy", &StateVector<double>::get_entropy, "Get the entropy of the vector.")
.def("get_entropy", &StateVector<Fp>::get_entropy, "Get the entropy of the vector.")
.def("add_state_vector_with_coef",
&StateVector<double>::add_state_vector_with_coef,
&StateVector<Fp>::add_state_vector_with_coef,
"add other state vector with multiplying the coef and make superposition. "
"$\\ket{\\mathrm{this}}\\leftarrow\\ket{\\mathrm{this}}+\\mathrm{coef}"
"\\ket{\\mathrm{"
"state}}$.")
.def("multiply_coef",
&StateVector<double>::multiply_coef,
&StateVector<Fp>::multiply_coef,
"Multiply coef. "
"$\\ket{\\mathrm{this}}\\leftarrow\\mathrm{coef}\\ket{\\mathrm{this}}$.")
.def(
"sampling",
[](const StateVector<double>& state,
[](const StateVector<Fp>& state,
std::uint64_t sampling_count,
std::optional<std::uint64_t> seed) {
return state.sampling(sampling_count, seed.value_or(std::random_device{}()));
},
"sampling_count"_a,
"seed"_a = std::nullopt,
"Sampling specified times. Result is `list[int]` with the `sampling_count` length.")
.def("to_string", &StateVector<double>::to_string, "Information as `str`.")
.def(
"load", &StateVector<double>::load, "Load amplitudes of `list[int]` with `dim` length.")
.def("__str__", &StateVector<double>::to_string, "Information as `str`.")
.def("to_string", &StateVector<Fp>::to_string, "Information as `str`.")
.def("load", &StateVector<Fp>::load, "Load amplitudes of `list[int]` with `dim` length.")
.def("__str__", &StateVector<Fp>::to_string, "Information as `str`.")
.def_ro_static("UNMEASURED",
&StateVector<double>::UNMEASURED,
&StateVector<Fp>::UNMEASURED,
"Constant used for `StateVector::get_marginal_probability` to express the "
"the qubit is not measured.");
}
Expand Down
Loading

0 comments on commit 239438d

Please sign in to comment.