Skip to content

Commit

Permalink
refac: use c::base::xxx_cast() for casting fields and points
Browse files Browse the repository at this point in the history
  • Loading branch information
chokobole committed Jun 18, 2024
1 parent 828830a commit cb59b42
Show file tree
Hide file tree
Showing 18 changed files with 49 additions and 49 deletions.
21 changes: 9 additions & 12 deletions benchmark/fft/fft_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "tachyon/base/functional/functor_traits.h"
#include "tachyon/base/time/time.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/fr.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h"
#include "tachyon/c/math/polynomials/univariate/bn254_univariate_evaluation_domain.h"

namespace tachyon {
Expand Down Expand Up @@ -74,25 +75,21 @@ class FFTRunner {
std::unique_ptr<F> ret;
if constexpr (std::is_same_v<PolyOrEvals, typename Domain::Evals>) {
const F omega_inv = domains_[i]->group_gen_inv();
ret.reset(reinterpret_cast<F*>(
fn(reinterpret_cast<const tachyon_bn254_fr*>(
(*polys_)[i].evaluations().data()),
(*polys_)[i].Degree(),
reinterpret_cast<const tachyon_bn254_fr*>(&omega_inv),
exponents[i], &duration_in_us)));
ret.reset(c::base::native_cast(
fn(c::base::c_cast((*polys_)[i].evaluations().data()),
(*polys_)[i].Degree(), c::base::c_cast(&omega_inv), exponents[i],
&duration_in_us)));
std::vector<F> res_vec(ret.get(), ret.get() + (*polys_)[i].Degree());
results->emplace_back(
typename RetPoly::Coefficients(std::move(res_vec)));
// NOLINTNEXTLINE(readability/braces)
} else if constexpr (std::is_same_v<PolyOrEvals,
typename Domain::DensePoly>) {
const F omega = domains_[i]->group_gen();
ret.reset(reinterpret_cast<F*>(
fn(reinterpret_cast<const tachyon_bn254_fr*>(
(*polys_)[i].coefficients().coefficients().data()),
(*polys_)[i].Degree(),
reinterpret_cast<const tachyon_bn254_fr*>(&omega), exponents[i],
&duration_in_us)));
ret.reset(c::base::native_cast(fn(
c::base::c_cast((*polys_)[i].coefficients().coefficients().data()),
(*polys_)[i].Degree(), c::base::c_cast(&omega), exponents[i],
&duration_in_us)));
std::vector<F> res_vec(ret.get(), ret.get() + (*polys_)[i].Degree());
results->emplace_back(std::move(res_vec));
}
Expand Down
1 change: 1 addition & 0 deletions benchmark/msm/msm_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "benchmark/msm/simple_msm_benchmark_reporter.h"
// clang-format on
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/msm.h"

namespace tachyon {
Expand Down
1 change: 1 addition & 0 deletions benchmark/msm/msm_benchmark_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "benchmark/msm/simple_msm_benchmark_reporter.h"
// clang-format on
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/msm.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/msm_gpu.h"
#include "tachyon/math/elliptic_curves/msm/test/variable_base_msm_test_set.h"
Expand Down
16 changes: 8 additions & 8 deletions benchmark/msm/msm_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "benchmark/msm/simple_msm_benchmark_reporter.h"
// clang-format on
#include "tachyon/base/time/time.h"
#include "tachyon/c/base/type_traits_forward.h"
#include "tachyon/c/math/elliptic_curves/point_traits_forward.h"
#include "tachyon/math/base/semigroups.h"

Expand Down Expand Up @@ -47,11 +48,10 @@ class MSMRunner {
for (size_t i = 0; i < point_nums.size(); ++i) {
base::TimeTicks now = base::TimeTicks::Now();
std::unique_ptr<CRetPoint> ret;
ret.reset(fn(msm, reinterpret_cast<const CPoint*>(bases_->data()),
reinterpret_cast<const CScalarField*>(scalars_->data()),
point_nums[i]));
ret.reset(fn(msm, c::base::c_cast(bases_->data()),
c::base::c_cast(scalars_->data()), point_nums[i]));
reporter_->AddTime(i, (base::TimeTicks::Now() - now).InSecondsF());
results->push_back(*reinterpret_cast<RetPoint*>(ret.get()));
results->push_back(*c::base::native_cast(ret.get()));
}
}

Expand All @@ -61,11 +61,11 @@ class MSMRunner {
for (size_t i = 0; i < point_nums.size(); ++i) {
std::unique_ptr<CRetPoint> ret;
uint64_t duration_in_us;
ret.reset(fn(reinterpret_cast<const CPoint*>(bases_->data()),
reinterpret_cast<const CScalarField*>(scalars_->data()),
point_nums[i], &duration_in_us));
ret.reset(fn(c::base::c_cast(bases_->data()),
c::base::c_cast(scalars_->data()), point_nums[i],
&duration_in_us));
reporter_->AddTime(i, base::Microseconds(duration_in_us).InSecondsF());
results->push_back(*reinterpret_cast<RetPoint*>(ret.get()));
results->push_back(*c::base::native_cast(ret.get()));
}
}

Expand Down
2 changes: 1 addition & 1 deletion benchmark/poseidon/poseidon_benchmark_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class PoseidonBenchmarkRunner {
ret.reset(fn(&duration_in_us));
reporter_->AddTime(i, base::Microseconds(duration_in_us).InSecondsF());
}
return *reinterpret_cast<Field*>(ret.get());
return *c::base::native_cast(ret.get());
}

private:
Expand Down
2 changes: 1 addition & 1 deletion benchmark/poseidon2/poseidon2_benchmark_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class PoseidonBenchmarkRunner {
ret.reset(fn(&duration_in_us));
reporter_->AddTime(i, base::Microseconds(duration_in_us).InSecondsF());
}
return *reinterpret_cast<Field*>(ret.get());
return *c::base::native_cast(ret.get());
}

private:
Expand Down
1 change: 1 addition & 0 deletions tachyon/c/math/elliptic_curves/generator/msm.cc.tpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// clang-format off
#include "tachyon/c/math/elliptic_curves/%{header_dir_name}/g1_point_traits.h"
#include "tachyon/c/math/elliptic_curves/%{header_dir_name}/g1_point_type_traits.h"
#include "tachyon/c/math/elliptic_curves/%{header_dir_name}/fr_type_traits.h"
#include "tachyon/c/math/elliptic_curves/msm/msm.h"
#include "tachyon/math/elliptic_curves/%{header_dir_name}/g1.h"

Expand Down
7 changes: 3 additions & 4 deletions tachyon/c/math/elliptic_curves/msm/msm_gpu_replay.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ int RealMain(int argc, char** argv) {

base::TimeTicks now = base::TimeTicks::Now();
std::unique_ptr<tachyon_bn254_g1_jacobian> ret(
tachyon_bn254_g1_affine_msm_gpu(
msm, reinterpret_cast<const tachyon_bn254_g1_affine*>(bases.data()),
reinterpret_cast<const tachyon_bn254_fr*>(scalars.data()),
scalars.size()));
tachyon_bn254_g1_affine_msm_gpu(msm, c::base::c_cast(bases.data()),
c::base::c_cast(scalars.data()),
scalars.size()));
std::cout << (base::TimeTicks::Now() - now) << std::endl;
std::cout << c::base::native_cast(*ret).ToAffine().ToHexString()
<< std::endl;
Expand Down
6 changes: 2 additions & 4 deletions tachyon/c/math/elliptic_curves/msm/msm_gpu_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ TEST_P(MSMGpuTest, MSMPoint2) {
return Point2<bn254::Fq>(t.bases[i].x(), t.bases[i].y());
});
ret.reset(tachyon_bn254_g1_point2_msm_gpu(
msm, reinterpret_cast<const tachyon_bn254_g1_point2*>(bases.data()),
reinterpret_cast<const tachyon_bn254_fr*>(t.scalars.data()),
msm, c::base::c_cast(bases.data()), c::base::c_cast(t.scalars.data()),
t.scalars.size()));
EXPECT_EQ(c::base::native_cast(*ret), t.answer.ToJacobian());
}
Expand All @@ -66,8 +65,7 @@ TEST_P(MSMGpuTest, MSMG1Affine) {
this->test_sets_) {
std::unique_ptr<tachyon_bn254_g1_jacobian> ret;
ret.reset(tachyon_bn254_g1_affine_msm_gpu(
msm, reinterpret_cast<const tachyon_bn254_g1_affine*>(t.bases.data()),
reinterpret_cast<const tachyon_bn254_fr*>(t.scalars.data()),
msm, c::base::c_cast(t.bases.data()), c::base::c_cast(t.scalars.data()),
t.scalars.size()));
EXPECT_EQ(c::base::native_cast(*ret), t.answer.ToJacobian());
}
Expand Down
7 changes: 3 additions & 4 deletions tachyon/c/math/elliptic_curves/msm/msm_input_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "absl/types/span.h"

#include "tachyon/base/openmp_util.h"
#include "tachyon/c/base/type_traits_forward.h"
#include "tachyon/c/math/elliptic_curves/point_traits_forward.h"
#include "tachyon/math/geometry/point2.h"

Expand Down Expand Up @@ -38,8 +39,7 @@ class MSMInputProvider {
OPENMP_PARALLEL_FOR(size_t i = 0; i < aligned_size; ++i) {
if (i < size) {
bases_owned_[i] = reinterpret_cast<const AffinePoint*>(bases_in)[i];
scalars_owned_[i] =
reinterpret_cast<const ScalarField*>(scalars_in)[i];
scalars_owned_[i] = base::native_cast(scalars_in)[i];
} else {
bases_owned_[i] = AffinePoint::Zero();
scalars_owned_[i] = ScalarField::Zero();
Expand All @@ -50,8 +50,7 @@ class MSMInputProvider {
} else {
bases_ = absl::MakeConstSpan(
reinterpret_cast<const AffinePoint*>(bases_in), size);
scalars_ = absl::MakeConstSpan(
reinterpret_cast<const ScalarField*>(scalars_in), size);
scalars_ = absl::MakeConstSpan(base::native_cast(scalars_in), size);
}
}

Expand Down
14 changes: 6 additions & 8 deletions tachyon/c/math/elliptic_curves/msm/msm_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,19 @@ TEST_F(MSMTest, MSMPoint2) {
base::CreateVector(t.bases.size(), [&t](size_t i) {
return Point2<bn254::Fq>(t.bases[i].x(), t.bases[i].y());
});
ret.reset(tachyon_bn254_g1_point2_msm(
msm_, reinterpret_cast<const tachyon_bn254_g1_point2*>(bases.data()),
reinterpret_cast<const tachyon_bn254_fr*>(t.scalars.data()),
t.scalars.size()));
ret.reset(tachyon_bn254_g1_point2_msm(msm_, c::base::c_cast(bases.data()),
c::base::c_cast(t.scalars.data()),
t.scalars.size()));
EXPECT_EQ(c::base::native_cast(*ret), t.answer.ToJacobian());
}
}

TEST_F(MSMTest, MSMG1Affine) {
for (const VariableBaseMSMTestSet<bn254::G1AffinePoint>& t : test_sets_) {
std::unique_ptr<tachyon_bn254_g1_jacobian> ret;
ret.reset(tachyon_bn254_g1_affine_msm(
msm_, reinterpret_cast<const tachyon_bn254_g1_affine*>(t.bases.data()),
reinterpret_cast<const tachyon_bn254_fr*>(t.scalars.data()),
t.scalars.size()));
ret.reset(tachyon_bn254_g1_affine_msm(msm_, c::base::c_cast(t.bases.data()),
c::base::c_cast(t.scalars.data()),
t.scalars.size()));
EXPECT_EQ(c::base::native_cast(*ret), t.answer.ToJacobian());
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "tachyon/c/math/polynomials/univariate/bn254_univariate_evaluations.h"

#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h"
#include "tachyon/c/math/polynomials/constants.h"
#include "tachyon/math/elliptic_curves/bn/bn254/fr.h"
#include "tachyon/math/polynomials/univariate/univariate_evaluations.h"
Expand Down Expand Up @@ -35,5 +36,5 @@ void tachyon_bn254_univariate_evaluations_set_value(
const tachyon_bn254_fr* value) {
// NOTE(chokobole): Boundary check is the responsibility of API callers.
reinterpret_cast<Evals&>(*evals).at(i) =
reinterpret_cast<const bn254::Fr&>(*value);
tachyon::c::base::native_cast(*value);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <vector>

#include "tachyon/base/logging.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h"
#include "tachyon/c/math/polynomials/constants.h"
#include "tachyon/math/base/rational_field.h"
#include "tachyon/math/elliptic_curves/bn/bn254/fr.h"
Expand Down Expand Up @@ -52,16 +53,16 @@ void tachyon_bn254_univariate_rational_evaluations_set_trivial(
const tachyon_bn254_fr* numerator) {
// NOTE(chokobole): Boundary check is the responsibility of API callers.
reinterpret_cast<RationalEvals&>(*evals).at(i) =
RationalField<bn254::Fr>(reinterpret_cast<const bn254::Fr&>(*numerator));
RationalField<bn254::Fr>(tachyon::c::base::native_cast(*numerator));
}

void tachyon_bn254_univariate_rational_evaluations_set_rational(
tachyon_bn254_univariate_rational_evaluations* evals, size_t i,
const tachyon_bn254_fr* numerator, const tachyon_bn254_fr* denominator) {
// NOTE(chokobole): Boundary check is the responsibility of API callers.
reinterpret_cast<RationalEvals&>(*evals).at(i) = {
reinterpret_cast<const bn254::Fr&>(*numerator),
reinterpret_cast<const bn254::Fr&>(*denominator),
tachyon::c::base::native_cast(*numerator),
tachyon::c::base::native_cast(*denominator),
};
}

Expand Down
1 change: 1 addition & 0 deletions tachyon/c/zk/plonk/halo2/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ tachyon_cc_library(
deps = [
":prover_impl_base",
"//tachyon/base:logging",
"//tachyon/c/base:type_traits_forward",
"//tachyon/c/math/elliptic_curves:point_traits_forward",
"//tachyon/math/elliptic_curves/msm:variable_base_msm",
"@com_google_absl//absl/types:span",
Expand Down
3 changes: 2 additions & 1 deletion tachyon/c/zk/plonk/halo2/bn254_argument_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <utility>

#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h"
#include "tachyon/c/math/polynomials/constants.h"
#include "tachyon/math/elliptic_curves/bn/bn254/fr.h"
#include "tachyon/math/polynomials/univariate/univariate_evaluations.h"
Expand Down Expand Up @@ -93,5 +94,5 @@ void tachyon_halo2_bn254_argument_data_reserve_challenges(
void tachyon_halo2_bn254_argument_data_add_challenge(
tachyon_halo2_bn254_argument_data* data, const tachyon_bn254_fr* value) {
reinterpret_cast<Data*>(data)->challenges().push_back(
reinterpret_cast<const math::bn254::Fr&>(*value));
tachyon::c::base::native_cast(*value));
}
1 change: 1 addition & 0 deletions tachyon/c/zk/plonk/halo2/bn254_gwc_prover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "tachyon/base/logging.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h"
#include "tachyon/c/zk/plonk/halo2/bn254_gwc_pcs.h"
#include "tachyon/c/zk/plonk/halo2/bn254_ls.h"
#include "tachyon/c/zk/plonk/halo2/bn254_transcript.h"
Expand Down
1 change: 1 addition & 0 deletions tachyon/c/zk/plonk/halo2/bn254_shplonk_prover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "tachyon/base/logging.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h"
#include "tachyon/c/zk/plonk/halo2/bn254_ls.h"
#include "tachyon/c/zk/plonk/halo2/bn254_shplonk_pcs.h"
#include "tachyon/c/zk/plonk/halo2/bn254_transcript.h"
Expand Down
4 changes: 2 additions & 2 deletions tachyon/c/zk/plonk/halo2/kzg_family_prover_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "absl/types/span.h"

#include "tachyon/base/logging.h"
#include "tachyon/c/base/type_traits_forward.h"
#include "tachyon/c/math/elliptic_curves/point_traits_forward.h"
#include "tachyon/c/zk/plonk/halo2/prover_impl_base.h"
#include "tachyon/math/elliptic_curves/msm/variable_base_msm.h"
Expand Down Expand Up @@ -46,8 +47,7 @@ class KZGFamilyProverImpl : public ProverImplBase<PCS, LS> {
absl::Span<const AffinePoint> bases_span(
bases.data(), std::min(bases.size(), scalars.size()));
CHECK(msm.Run(bases_span, scalars, &bucket));
JacobianPoint* ret = new JacobianPoint(bucket.ToJacobian());
return reinterpret_cast<CJacobianPoint*>(ret);
return base::c_cast(new JacobianPoint(bucket.ToJacobian()));
}
};

Expand Down

0 comments on commit cb59b42

Please sign in to comment.