From cb59b42c450194d32c1001ff0ad2b16454cce688 Mon Sep 17 00:00:00 2001 From: Wonyong Kim Date: Mon, 17 Jun 2024 00:26:25 +0900 Subject: [PATCH] refac: use `c::base::xxx_cast()` for casting fields and points --- benchmark/fft/fft_runner.h | 21 ++++++++----------- benchmark/msm/msm_benchmark.cc | 1 + benchmark/msm/msm_benchmark_gpu.cc | 1 + benchmark/msm/msm_runner.h | 16 +++++++------- .../poseidon/poseidon_benchmark_runner.h | 2 +- .../poseidon2/poseidon2_benchmark_runner.h | 2 +- .../math/elliptic_curves/generator/msm.cc.tpl | 1 + .../elliptic_curves/msm/msm_gpu_replay.cc | 7 +++---- .../elliptic_curves/msm/msm_gpu_unittest.cc | 6 ++---- .../elliptic_curves/msm/msm_input_provider.h | 7 +++---- .../math/elliptic_curves/msm/msm_unittest.cc | 14 ++++++------- .../bn254_univariate_evaluations.cc | 3 ++- .../bn254_univariate_rational_evaluations.cc | 7 ++++--- tachyon/c/zk/plonk/halo2/BUILD.bazel | 1 + .../c/zk/plonk/halo2/bn254_argument_data.cc | 3 ++- tachyon/c/zk/plonk/halo2/bn254_gwc_prover.cc | 1 + .../c/zk/plonk/halo2/bn254_shplonk_prover.cc | 1 + .../c/zk/plonk/halo2/kzg_family_prover_impl.h | 4 ++-- 18 files changed, 49 insertions(+), 49 deletions(-) diff --git a/benchmark/fft/fft_runner.h b/benchmark/fft/fft_runner.h index 5f7fbd62c..c9fec2649 100644 --- a/benchmark/fft/fft_runner.h +++ b/benchmark/fft/fft_runner.h @@ -15,6 +15,7 @@ #include "tachyon/base/functional/functor_traits.h" #include "tachyon/base/time/time.h" #include "tachyon/c/math/elliptic_curves/bn/bn254/fr.h" +#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h" #include "tachyon/c/math/polynomials/univariate/bn254_univariate_evaluation_domain.h" namespace tachyon { @@ -74,12 +75,10 @@ class FFTRunner { std::unique_ptr ret; if constexpr (std::is_same_v) { const F omega_inv = domains_[i]->group_gen_inv(); - ret.reset(reinterpret_cast( - fn(reinterpret_cast( - (*polys_)[i].evaluations().data()), - (*polys_)[i].Degree(), - reinterpret_cast(&omega_inv), - exponents[i], &duration_in_us))); + ret.reset(c::base::native_cast( + fn(c::base::c_cast((*polys_)[i].evaluations().data()), + (*polys_)[i].Degree(), c::base::c_cast(&omega_inv), exponents[i], + &duration_in_us))); std::vector res_vec(ret.get(), ret.get() + (*polys_)[i].Degree()); results->emplace_back( typename RetPoly::Coefficients(std::move(res_vec))); @@ -87,12 +86,10 @@ class FFTRunner { } else if constexpr (std::is_same_v) { const F omega = domains_[i]->group_gen(); - ret.reset(reinterpret_cast( - fn(reinterpret_cast( - (*polys_)[i].coefficients().coefficients().data()), - (*polys_)[i].Degree(), - reinterpret_cast(&omega), exponents[i], - &duration_in_us))); + ret.reset(c::base::native_cast(fn( + c::base::c_cast((*polys_)[i].coefficients().coefficients().data()), + (*polys_)[i].Degree(), c::base::c_cast(&omega), exponents[i], + &duration_in_us))); std::vector res_vec(ret.get(), ret.get() + (*polys_)[i].Degree()); results->emplace_back(std::move(res_vec)); } diff --git a/benchmark/msm/msm_benchmark.cc b/benchmark/msm/msm_benchmark.cc index bf002ecd4..a35d60d47 100644 --- a/benchmark/msm/msm_benchmark.cc +++ b/benchmark/msm/msm_benchmark.cc @@ -6,6 +6,7 @@ #include "benchmark/msm/simple_msm_benchmark_reporter.h" // clang-format on #include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h" +#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h" #include "tachyon/c/math/elliptic_curves/bn/bn254/msm.h" namespace tachyon { diff --git a/benchmark/msm/msm_benchmark_gpu.cc b/benchmark/msm/msm_benchmark_gpu.cc index d5c9c51ff..2f163c6ff 100644 --- a/benchmark/msm/msm_benchmark_gpu.cc +++ b/benchmark/msm/msm_benchmark_gpu.cc @@ -7,6 +7,7 @@ #include "benchmark/msm/simple_msm_benchmark_reporter.h" // clang-format on #include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h" +#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h" #include "tachyon/c/math/elliptic_curves/bn/bn254/msm.h" #include "tachyon/c/math/elliptic_curves/bn/bn254/msm_gpu.h" #include "tachyon/math/elliptic_curves/msm/test/variable_base_msm_test_set.h" diff --git a/benchmark/msm/msm_runner.h b/benchmark/msm/msm_runner.h index 7f006b37e..e9ccd70d0 100644 --- a/benchmark/msm/msm_runner.h +++ b/benchmark/msm/msm_runner.h @@ -10,6 +10,7 @@ #include "benchmark/msm/simple_msm_benchmark_reporter.h" // clang-format on #include "tachyon/base/time/time.h" +#include "tachyon/c/base/type_traits_forward.h" #include "tachyon/c/math/elliptic_curves/point_traits_forward.h" #include "tachyon/math/base/semigroups.h" @@ -47,11 +48,10 @@ class MSMRunner { for (size_t i = 0; i < point_nums.size(); ++i) { base::TimeTicks now = base::TimeTicks::Now(); std::unique_ptr ret; - ret.reset(fn(msm, reinterpret_cast(bases_->data()), - reinterpret_cast(scalars_->data()), - point_nums[i])); + ret.reset(fn(msm, c::base::c_cast(bases_->data()), + c::base::c_cast(scalars_->data()), point_nums[i])); reporter_->AddTime(i, (base::TimeTicks::Now() - now).InSecondsF()); - results->push_back(*reinterpret_cast(ret.get())); + results->push_back(*c::base::native_cast(ret.get())); } } @@ -61,11 +61,11 @@ class MSMRunner { for (size_t i = 0; i < point_nums.size(); ++i) { std::unique_ptr ret; uint64_t duration_in_us; - ret.reset(fn(reinterpret_cast(bases_->data()), - reinterpret_cast(scalars_->data()), - point_nums[i], &duration_in_us)); + ret.reset(fn(c::base::c_cast(bases_->data()), + c::base::c_cast(scalars_->data()), point_nums[i], + &duration_in_us)); reporter_->AddTime(i, base::Microseconds(duration_in_us).InSecondsF()); - results->push_back(*reinterpret_cast(ret.get())); + results->push_back(*c::base::native_cast(ret.get())); } } diff --git a/benchmark/poseidon/poseidon_benchmark_runner.h b/benchmark/poseidon/poseidon_benchmark_runner.h index 7088a139f..f11edfe91 100644 --- a/benchmark/poseidon/poseidon_benchmark_runner.h +++ b/benchmark/poseidon/poseidon_benchmark_runner.h @@ -55,7 +55,7 @@ class PoseidonBenchmarkRunner { ret.reset(fn(&duration_in_us)); reporter_->AddTime(i, base::Microseconds(duration_in_us).InSecondsF()); } - return *reinterpret_cast(ret.get()); + return *c::base::native_cast(ret.get()); } private: diff --git a/benchmark/poseidon2/poseidon2_benchmark_runner.h b/benchmark/poseidon2/poseidon2_benchmark_runner.h index b8be35276..3ee7ef7cb 100644 --- a/benchmark/poseidon2/poseidon2_benchmark_runner.h +++ b/benchmark/poseidon2/poseidon2_benchmark_runner.h @@ -56,7 +56,7 @@ class PoseidonBenchmarkRunner { ret.reset(fn(&duration_in_us)); reporter_->AddTime(i, base::Microseconds(duration_in_us).InSecondsF()); } - return *reinterpret_cast(ret.get()); + return *c::base::native_cast(ret.get()); } private: diff --git a/tachyon/c/math/elliptic_curves/generator/msm.cc.tpl b/tachyon/c/math/elliptic_curves/generator/msm.cc.tpl index 02994e576..36e3caf6a 100644 --- a/tachyon/c/math/elliptic_curves/generator/msm.cc.tpl +++ b/tachyon/c/math/elliptic_curves/generator/msm.cc.tpl @@ -1,6 +1,7 @@ // clang-format off #include "tachyon/c/math/elliptic_curves/%{header_dir_name}/g1_point_traits.h" #include "tachyon/c/math/elliptic_curves/%{header_dir_name}/g1_point_type_traits.h" +#include "tachyon/c/math/elliptic_curves/%{header_dir_name}/fr_type_traits.h" #include "tachyon/c/math/elliptic_curves/msm/msm.h" #include "tachyon/math/elliptic_curves/%{header_dir_name}/g1.h" diff --git a/tachyon/c/math/elliptic_curves/msm/msm_gpu_replay.cc b/tachyon/c/math/elliptic_curves/msm/msm_gpu_replay.cc index 4cd9d6dea..44325655d 100644 --- a/tachyon/c/math/elliptic_curves/msm/msm_gpu_replay.cc +++ b/tachyon/c/math/elliptic_curves/msm/msm_gpu_replay.cc @@ -97,10 +97,9 @@ int RealMain(int argc, char** argv) { base::TimeTicks now = base::TimeTicks::Now(); std::unique_ptr ret( - tachyon_bn254_g1_affine_msm_gpu( - msm, reinterpret_cast(bases.data()), - reinterpret_cast(scalars.data()), - scalars.size())); + tachyon_bn254_g1_affine_msm_gpu(msm, c::base::c_cast(bases.data()), + c::base::c_cast(scalars.data()), + scalars.size())); std::cout << (base::TimeTicks::Now() - now) << std::endl; std::cout << c::base::native_cast(*ret).ToAffine().ToHexString() << std::endl; diff --git a/tachyon/c/math/elliptic_curves/msm/msm_gpu_unittest.cc b/tachyon/c/math/elliptic_curves/msm/msm_gpu_unittest.cc index a102ccfaa..eb40de0e1 100644 --- a/tachyon/c/math/elliptic_curves/msm/msm_gpu_unittest.cc +++ b/tachyon/c/math/elliptic_curves/msm/msm_gpu_unittest.cc @@ -49,8 +49,7 @@ TEST_P(MSMGpuTest, MSMPoint2) { return Point2(t.bases[i].x(), t.bases[i].y()); }); ret.reset(tachyon_bn254_g1_point2_msm_gpu( - msm, reinterpret_cast(bases.data()), - reinterpret_cast(t.scalars.data()), + msm, c::base::c_cast(bases.data()), c::base::c_cast(t.scalars.data()), t.scalars.size())); EXPECT_EQ(c::base::native_cast(*ret), t.answer.ToJacobian()); } @@ -66,8 +65,7 @@ TEST_P(MSMGpuTest, MSMG1Affine) { this->test_sets_) { std::unique_ptr ret; ret.reset(tachyon_bn254_g1_affine_msm_gpu( - msm, reinterpret_cast(t.bases.data()), - reinterpret_cast(t.scalars.data()), + msm, c::base::c_cast(t.bases.data()), c::base::c_cast(t.scalars.data()), t.scalars.size())); EXPECT_EQ(c::base::native_cast(*ret), t.answer.ToJacobian()); } diff --git a/tachyon/c/math/elliptic_curves/msm/msm_input_provider.h b/tachyon/c/math/elliptic_curves/msm/msm_input_provider.h index fe62a891f..b43a29d0f 100644 --- a/tachyon/c/math/elliptic_curves/msm/msm_input_provider.h +++ b/tachyon/c/math/elliptic_curves/msm/msm_input_provider.h @@ -7,6 +7,7 @@ #include "absl/types/span.h" #include "tachyon/base/openmp_util.h" +#include "tachyon/c/base/type_traits_forward.h" #include "tachyon/c/math/elliptic_curves/point_traits_forward.h" #include "tachyon/math/geometry/point2.h" @@ -38,8 +39,7 @@ class MSMInputProvider { OPENMP_PARALLEL_FOR(size_t i = 0; i < aligned_size; ++i) { if (i < size) { bases_owned_[i] = reinterpret_cast(bases_in)[i]; - scalars_owned_[i] = - reinterpret_cast(scalars_in)[i]; + scalars_owned_[i] = base::native_cast(scalars_in)[i]; } else { bases_owned_[i] = AffinePoint::Zero(); scalars_owned_[i] = ScalarField::Zero(); @@ -50,8 +50,7 @@ class MSMInputProvider { } else { bases_ = absl::MakeConstSpan( reinterpret_cast(bases_in), size); - scalars_ = absl::MakeConstSpan( - reinterpret_cast(scalars_in), size); + scalars_ = absl::MakeConstSpan(base::native_cast(scalars_in), size); } } diff --git a/tachyon/c/math/elliptic_curves/msm/msm_unittest.cc b/tachyon/c/math/elliptic_curves/msm/msm_unittest.cc index 2277f57ef..d3ebd98eb 100644 --- a/tachyon/c/math/elliptic_curves/msm/msm_unittest.cc +++ b/tachyon/c/math/elliptic_curves/msm/msm_unittest.cc @@ -43,10 +43,9 @@ TEST_F(MSMTest, MSMPoint2) { base::CreateVector(t.bases.size(), [&t](size_t i) { return Point2(t.bases[i].x(), t.bases[i].y()); }); - ret.reset(tachyon_bn254_g1_point2_msm( - msm_, reinterpret_cast(bases.data()), - reinterpret_cast(t.scalars.data()), - t.scalars.size())); + ret.reset(tachyon_bn254_g1_point2_msm(msm_, c::base::c_cast(bases.data()), + c::base::c_cast(t.scalars.data()), + t.scalars.size())); EXPECT_EQ(c::base::native_cast(*ret), t.answer.ToJacobian()); } } @@ -54,10 +53,9 @@ TEST_F(MSMTest, MSMPoint2) { TEST_F(MSMTest, MSMG1Affine) { for (const VariableBaseMSMTestSet& t : test_sets_) { std::unique_ptr ret; - ret.reset(tachyon_bn254_g1_affine_msm( - msm_, reinterpret_cast(t.bases.data()), - reinterpret_cast(t.scalars.data()), - t.scalars.size())); + ret.reset(tachyon_bn254_g1_affine_msm(msm_, c::base::c_cast(t.bases.data()), + c::base::c_cast(t.scalars.data()), + t.scalars.size())); EXPECT_EQ(c::base::native_cast(*ret), t.answer.ToJacobian()); } } diff --git a/tachyon/c/math/polynomials/univariate/bn254_univariate_evaluations.cc b/tachyon/c/math/polynomials/univariate/bn254_univariate_evaluations.cc index bd155e29d..1605f5dbf 100644 --- a/tachyon/c/math/polynomials/univariate/bn254_univariate_evaluations.cc +++ b/tachyon/c/math/polynomials/univariate/bn254_univariate_evaluations.cc @@ -1,5 +1,6 @@ #include "tachyon/c/math/polynomials/univariate/bn254_univariate_evaluations.h" +#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h" #include "tachyon/c/math/polynomials/constants.h" #include "tachyon/math/elliptic_curves/bn/bn254/fr.h" #include "tachyon/math/polynomials/univariate/univariate_evaluations.h" @@ -35,5 +36,5 @@ void tachyon_bn254_univariate_evaluations_set_value( const tachyon_bn254_fr* value) { // NOTE(chokobole): Boundary check is the responsibility of API callers. reinterpret_cast(*evals).at(i) = - reinterpret_cast(*value); + tachyon::c::base::native_cast(*value); } diff --git a/tachyon/c/math/polynomials/univariate/bn254_univariate_rational_evaluations.cc b/tachyon/c/math/polynomials/univariate/bn254_univariate_rational_evaluations.cc index 5f90dd832..9ff98777e 100644 --- a/tachyon/c/math/polynomials/univariate/bn254_univariate_rational_evaluations.cc +++ b/tachyon/c/math/polynomials/univariate/bn254_univariate_rational_evaluations.cc @@ -4,6 +4,7 @@ #include #include "tachyon/base/logging.h" +#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h" #include "tachyon/c/math/polynomials/constants.h" #include "tachyon/math/base/rational_field.h" #include "tachyon/math/elliptic_curves/bn/bn254/fr.h" @@ -52,7 +53,7 @@ void tachyon_bn254_univariate_rational_evaluations_set_trivial( const tachyon_bn254_fr* numerator) { // NOTE(chokobole): Boundary check is the responsibility of API callers. reinterpret_cast(*evals).at(i) = - RationalField(reinterpret_cast(*numerator)); + RationalField(tachyon::c::base::native_cast(*numerator)); } void tachyon_bn254_univariate_rational_evaluations_set_rational( @@ -60,8 +61,8 @@ void tachyon_bn254_univariate_rational_evaluations_set_rational( const tachyon_bn254_fr* numerator, const tachyon_bn254_fr* denominator) { // NOTE(chokobole): Boundary check is the responsibility of API callers. reinterpret_cast(*evals).at(i) = { - reinterpret_cast(*numerator), - reinterpret_cast(*denominator), + tachyon::c::base::native_cast(*numerator), + tachyon::c::base::native_cast(*denominator), }; } diff --git a/tachyon/c/zk/plonk/halo2/BUILD.bazel b/tachyon/c/zk/plonk/halo2/BUILD.bazel index f41e17556..a9dd6508c 100644 --- a/tachyon/c/zk/plonk/halo2/BUILD.bazel +++ b/tachyon/c/zk/plonk/halo2/BUILD.bazel @@ -260,6 +260,7 @@ tachyon_cc_library( deps = [ ":prover_impl_base", "//tachyon/base:logging", + "//tachyon/c/base:type_traits_forward", "//tachyon/c/math/elliptic_curves:point_traits_forward", "//tachyon/math/elliptic_curves/msm:variable_base_msm", "@com_google_absl//absl/types:span", diff --git a/tachyon/c/zk/plonk/halo2/bn254_argument_data.cc b/tachyon/c/zk/plonk/halo2/bn254_argument_data.cc index 19abbff42..ecee4c263 100644 --- a/tachyon/c/zk/plonk/halo2/bn254_argument_data.cc +++ b/tachyon/c/zk/plonk/halo2/bn254_argument_data.cc @@ -2,6 +2,7 @@ #include +#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h" #include "tachyon/c/math/polynomials/constants.h" #include "tachyon/math/elliptic_curves/bn/bn254/fr.h" #include "tachyon/math/polynomials/univariate/univariate_evaluations.h" @@ -93,5 +94,5 @@ void tachyon_halo2_bn254_argument_data_reserve_challenges( void tachyon_halo2_bn254_argument_data_add_challenge( tachyon_halo2_bn254_argument_data* data, const tachyon_bn254_fr* value) { reinterpret_cast(data)->challenges().push_back( - reinterpret_cast(*value)); + tachyon::c::base::native_cast(*value)); } diff --git a/tachyon/c/zk/plonk/halo2/bn254_gwc_prover.cc b/tachyon/c/zk/plonk/halo2/bn254_gwc_prover.cc index b3e116b34..19ede9159 100644 --- a/tachyon/c/zk/plonk/halo2/bn254_gwc_prover.cc +++ b/tachyon/c/zk/plonk/halo2/bn254_gwc_prover.cc @@ -8,6 +8,7 @@ #include "tachyon/base/logging.h" #include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h" +#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h" #include "tachyon/c/zk/plonk/halo2/bn254_gwc_pcs.h" #include "tachyon/c/zk/plonk/halo2/bn254_ls.h" #include "tachyon/c/zk/plonk/halo2/bn254_transcript.h" diff --git a/tachyon/c/zk/plonk/halo2/bn254_shplonk_prover.cc b/tachyon/c/zk/plonk/halo2/bn254_shplonk_prover.cc index e185095cc..08561755b 100644 --- a/tachyon/c/zk/plonk/halo2/bn254_shplonk_prover.cc +++ b/tachyon/c/zk/plonk/halo2/bn254_shplonk_prover.cc @@ -8,6 +8,7 @@ #include "tachyon/base/logging.h" #include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h" +#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h" #include "tachyon/c/zk/plonk/halo2/bn254_ls.h" #include "tachyon/c/zk/plonk/halo2/bn254_shplonk_pcs.h" #include "tachyon/c/zk/plonk/halo2/bn254_transcript.h" diff --git a/tachyon/c/zk/plonk/halo2/kzg_family_prover_impl.h b/tachyon/c/zk/plonk/halo2/kzg_family_prover_impl.h index 03a11ea56..b8c18950b 100644 --- a/tachyon/c/zk/plonk/halo2/kzg_family_prover_impl.h +++ b/tachyon/c/zk/plonk/halo2/kzg_family_prover_impl.h @@ -8,6 +8,7 @@ #include "absl/types/span.h" #include "tachyon/base/logging.h" +#include "tachyon/c/base/type_traits_forward.h" #include "tachyon/c/math/elliptic_curves/point_traits_forward.h" #include "tachyon/c/zk/plonk/halo2/prover_impl_base.h" #include "tachyon/math/elliptic_curves/msm/variable_base_msm.h" @@ -46,8 +47,7 @@ class KZGFamilyProverImpl : public ProverImplBase { absl::Span bases_span( bases.data(), std::min(bases.size(), scalars.size())); CHECK(msm.Run(bases_span, scalars, &bucket)); - JacobianPoint* ret = new JacobianPoint(bucket.ToJacobian()); - return reinterpret_cast(ret); + return base::c_cast(new JacobianPoint(bucket.ToJacobian())); } };