Skip to content

Commit

Permalink
Merge pull request kroma-network#438 from kroma-network/apply-c-type-…
Browse files Browse the repository at this point in the history
…traits

refac: apply c type traits to point types
  • Loading branch information
chokobole authored Jun 18, 2024
2 parents c9e6f93 + cc964ca commit 2e2c671
Show file tree
Hide file tree
Showing 57 changed files with 631 additions and 633 deletions.
21 changes: 9 additions & 12 deletions benchmark/fft/fft_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "tachyon/base/functional/functor_traits.h"
#include "tachyon/base/time/time.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/fr.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h"
#include "tachyon/c/math/polynomials/univariate/bn254_univariate_evaluation_domain.h"

namespace tachyon {
Expand Down Expand Up @@ -74,25 +75,21 @@ class FFTRunner {
std::unique_ptr<F> ret;
if constexpr (std::is_same_v<PolyOrEvals, typename Domain::Evals>) {
const F omega_inv = domains_[i]->group_gen_inv();
ret.reset(reinterpret_cast<F*>(
fn(reinterpret_cast<const tachyon_bn254_fr*>(
(*polys_)[i].evaluations().data()),
(*polys_)[i].Degree(),
reinterpret_cast<const tachyon_bn254_fr*>(&omega_inv),
exponents[i], &duration_in_us)));
ret.reset(c::base::native_cast(
fn(c::base::c_cast((*polys_)[i].evaluations().data()),
(*polys_)[i].Degree(), c::base::c_cast(&omega_inv), exponents[i],
&duration_in_us)));
std::vector<F> res_vec(ret.get(), ret.get() + (*polys_)[i].Degree());
results->emplace_back(
typename RetPoly::Coefficients(std::move(res_vec)));
// NOLINTNEXTLINE(readability/braces)
} else if constexpr (std::is_same_v<PolyOrEvals,
typename Domain::DensePoly>) {
const F omega = domains_[i]->group_gen();
ret.reset(reinterpret_cast<F*>(
fn(reinterpret_cast<const tachyon_bn254_fr*>(
(*polys_)[i].coefficients().coefficients().data()),
(*polys_)[i].Degree(),
reinterpret_cast<const tachyon_bn254_fr*>(&omega), exponents[i],
&duration_in_us)));
ret.reset(c::base::native_cast(fn(
c::base::c_cast((*polys_)[i].coefficients().coefficients().data()),
(*polys_)[i].Degree(), c::base::c_cast(&omega), exponents[i],
&duration_in_us)));
std::vector<F> res_vec(ret.get(), ret.get() + (*polys_)[i].Degree());
results->emplace_back(std::move(res_vec));
}
Expand Down
1 change: 1 addition & 0 deletions benchmark/msm/msm_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "benchmark/msm/simple_msm_benchmark_reporter.h"
// clang-format on
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/msm.h"

namespace tachyon {
Expand Down
1 change: 1 addition & 0 deletions benchmark/msm/msm_benchmark_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "benchmark/msm/simple_msm_benchmark_reporter.h"
// clang-format on
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/g1_point_type_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/msm.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/msm_gpu.h"
#include "tachyon/math/elliptic_curves/msm/test/variable_base_msm_test_set.h"
Expand Down
16 changes: 8 additions & 8 deletions benchmark/msm/msm_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "benchmark/msm/simple_msm_benchmark_reporter.h"
// clang-format on
#include "tachyon/base/time/time.h"
#include "tachyon/c/base/type_traits_forward.h"
#include "tachyon/c/math/elliptic_curves/point_traits_forward.h"
#include "tachyon/math/base/semigroups.h"

Expand Down Expand Up @@ -47,11 +48,10 @@ class MSMRunner {
for (size_t i = 0; i < point_nums.size(); ++i) {
base::TimeTicks now = base::TimeTicks::Now();
std::unique_ptr<CRetPoint> ret;
ret.reset(fn(msm, reinterpret_cast<const CPoint*>(bases_->data()),
reinterpret_cast<const CScalarField*>(scalars_->data()),
point_nums[i]));
ret.reset(fn(msm, c::base::c_cast(bases_->data()),
c::base::c_cast(scalars_->data()), point_nums[i]));
reporter_->AddTime(i, (base::TimeTicks::Now() - now).InSecondsF());
results->push_back(*reinterpret_cast<RetPoint*>(ret.get()));
results->push_back(*c::base::native_cast(ret.get()));
}
}

Expand All @@ -61,11 +61,11 @@ class MSMRunner {
for (size_t i = 0; i < point_nums.size(); ++i) {
std::unique_ptr<CRetPoint> ret;
uint64_t duration_in_us;
ret.reset(fn(reinterpret_cast<const CPoint*>(bases_->data()),
reinterpret_cast<const CScalarField*>(scalars_->data()),
point_nums[i], &duration_in_us));
ret.reset(fn(c::base::c_cast(bases_->data()),
c::base::c_cast(scalars_->data()), point_nums[i],
&duration_in_us));
reporter_->AddTime(i, base::Microseconds(duration_in_us).InSecondsF());
results->push_back(*reinterpret_cast<RetPoint*>(ret.get()));
results->push_back(*c::base::native_cast(ret.get()));
}
}

Expand Down
4 changes: 2 additions & 2 deletions benchmark/poseidon/poseidon_benchmark_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "tachyon/base/logging.h"
#include "tachyon/base/time/time.h"
#include "tachyon/c/base/type_traits_forward.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h"
#include "tachyon/crypto/hashes/sponge/poseidon/poseidon.h"

namespace tachyon {
Expand Down Expand Up @@ -55,7 +55,7 @@ class PoseidonBenchmarkRunner {
ret.reset(fn(&duration_in_us));
reporter_->AddTime(i, base::Microseconds(duration_in_us).InSecondsF());
}
return *reinterpret_cast<Field*>(ret.get());
return *c::base::native_cast(ret.get());
}

private:
Expand Down
4 changes: 2 additions & 2 deletions benchmark/poseidon2/poseidon2_benchmark_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "tachyon/base/logging.h"
#include "tachyon/base/time/time.h"
#include "tachyon/c/base/type_traits_forward.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_traits.h"
#include "tachyon/c/math/elliptic_curves/bn/bn254/fr_type_traits.h"
#include "tachyon/crypto/hashes/sponge/poseidon2/poseidon2.h"
#include "tachyon/crypto/hashes/sponge/poseidon2/poseidon2_horizen_external_matrix.h"

Expand Down Expand Up @@ -56,7 +56,7 @@ class PoseidonBenchmarkRunner {
ret.reset(fn(&duration_in_us));
reporter_->AddTime(i, base::Microseconds(duration_in_us).InSecondsF());
}
return *reinterpret_cast<Field*>(ret.get());
return *c::base::native_cast(ret.get());
}

private:
Expand Down
12 changes: 12 additions & 0 deletions tachyon/base/strings/string_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ std::u16string ToUpperASCII(std::u16string_view str) {
return internal::ToUpperASCIIImpl(str);
}

std::string CapitalizeASCII(std::string_view str) {
std::string ret = std::string(str);
if (!str.empty()) ret[0] = ToUpperASCII(str[0]);
return ret;
}

std::u16string CapitalizeASCII(std::u16string_view str) {
std::u16string ret = std::u16string(str);
if (!str.empty()) ret[0] = ToUpperASCII(str[0]);
return ret;
}

const std::string& EmptyString() {
static const NoDestructor<std::string> s;
return *s;
Expand Down
5 changes: 5 additions & 0 deletions tachyon/base/strings/string_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ TACHYON_EXPORT std::u16string ToLowerASCII(std::u16string_view str);
TACHYON_EXPORT std::string ToUpperASCII(std::string_view str);
TACHYON_EXPORT std::u16string ToUpperASCII(std::u16string_view str);

// Capitalize the given string. Non-ASCII bytes (or UTF-16 code units in
// `std::u16string_view`) are permitted but will be unmodified.
TACHYON_EXPORT std::string CapitalizeASCII(std::string_view str);
TACHYON_EXPORT std::u16string CapitalizeASCII(std::u16string_view str);

// Like strcasecmp for ASCII case-insensitive comparisons only. Returns:
// -1 (a < b)
// 0 (a == b)
Expand Down
5 changes: 3 additions & 2 deletions tachyon/c/math/elliptic_curves/generator/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@ tachyon_cc_binary(
data = [
"ext_field.cc.tpl",
"ext_field.h.tpl",
"ext_field_traits.h.tpl",
"ext_field_type_traits.h.tpl",
"msm.cc.tpl",
"msm.h.tpl",
"msm_gpu.cc.tpl",
"msm_gpu.h.tpl",
"point.cc.tpl",
"point.h.tpl",
"point_traits.h.tpl",
"point_type_traits.h.tpl",
"prime_field.cc.tpl",
"prime_field.h.tpl",
"prime_field_traits.h.tpl",
"prime_field_type_traits.h.tpl",
],
deps = [
":generator_util",
Expand Down
55 changes: 33 additions & 22 deletions tachyon/c/math/elliptic_curves/generator/build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ load("//bazel:tachyon_cc.bzl", "tachyon_cc_library", "tachyon_cuda_library")
def _generate_ec_point_impl(ctx):
prime_field_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:prime_field.h.tpl)", [ctx.attr.prime_field_hdr_tpl_path])
prime_field_src_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:prime_field.cc.tpl)", [ctx.attr.prime_field_src_tpl_path])
prime_field_traits_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:prime_field_traits.h.tpl)", [ctx.attr.prime_field_traits_hdr_tpl_path])
prime_field_type_traits_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:prime_field_type_traits.h.tpl)", [ctx.attr.prime_field_type_traits_hdr_tpl_path])
ext_field_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:ext_field.h.tpl)", [ctx.attr.ext_field_hdr_tpl_path])
ext_field_src_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:ext_field.cc.tpl)", [ctx.attr.ext_field_src_tpl_path])
ext_field_traits_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:ext_field_traits.h.tpl)", [ctx.attr.ext_field_traits_hdr_tpl_path])
ext_field_type_traits_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:ext_field_type_traits.h.tpl)", [ctx.attr.ext_field_type_traits_hdr_tpl_path])
point_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:point.h.tpl)", [ctx.attr.point_hdr_tpl_path])
point_src_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:point.cc.tpl)", [ctx.attr.point_src_tpl_path])
point_traits_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:point_traits.h.tpl)", [ctx.attr.point_traits_hdr_tpl_path])
point_type_traits_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:point_type_traits.h.tpl)", [ctx.attr.point_type_traits_hdr_tpl_path])
msm_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:msm.h.tpl)", [ctx.attr.msm_hdr_tpl_path])
msm_src_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:msm.cc.tpl)", [ctx.attr.msm_src_tpl_path])
msm_gpu_hdr_tpl_path = ctx.expand_location("$(location @kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:msm_gpu.h.tpl)", [ctx.attr.msm_gpu_hdr_tpl_path])
Expand All @@ -26,13 +27,14 @@ def _generate_ec_point_impl(ctx):
"--has_specialized_g1_msm_kernels=%s" % (ctx.attr.has_specialized_g1_msm_kernels),
"--prime_field_hdr_tpl_path=%s" % (prime_field_hdr_tpl_path),
"--prime_field_src_tpl_path=%s" % (prime_field_src_tpl_path),
"--prime_field_traits_hdr_tpl_path=%s" % (prime_field_traits_hdr_tpl_path),
"--prime_field_type_traits_hdr_tpl_path=%s" % (prime_field_type_traits_hdr_tpl_path),
"--ext_field_hdr_tpl_path=%s" % (ext_field_hdr_tpl_path),
"--ext_field_src_tpl_path=%s" % (ext_field_src_tpl_path),
"--ext_field_traits_hdr_tpl_path=%s" % (ext_field_traits_hdr_tpl_path),
"--ext_field_type_traits_hdr_tpl_path=%s" % (ext_field_type_traits_hdr_tpl_path),
"--point_hdr_tpl_path=%s" % (point_hdr_tpl_path),
"--point_src_tpl_path=%s" % (point_src_tpl_path),
"--point_traits_hdr_tpl_path=%s" % (point_traits_hdr_tpl_path),
"--point_type_traits_hdr_tpl_path=%s" % (point_type_traits_hdr_tpl_path),
"--msm_hdr_tpl_path=%s" % (msm_hdr_tpl_path),
"--msm_src_tpl_path=%s" % (msm_src_tpl_path),
"--msm_gpu_hdr_tpl_path=%s" % (msm_gpu_hdr_tpl_path),
Expand All @@ -43,13 +45,14 @@ def _generate_ec_point_impl(ctx):
inputs = [
ctx.files.prime_field_hdr_tpl_path[0],
ctx.files.prime_field_src_tpl_path[0],
ctx.files.prime_field_traits_hdr_tpl_path[0],
ctx.files.prime_field_type_traits_hdr_tpl_path[0],
ctx.files.ext_field_hdr_tpl_path[0],
ctx.files.ext_field_src_tpl_path[0],
ctx.files.ext_field_traits_hdr_tpl_path[0],
ctx.files.ext_field_type_traits_hdr_tpl_path[0],
ctx.files.point_hdr_tpl_path[0],
ctx.files.point_src_tpl_path[0],
ctx.files.point_traits_hdr_tpl_path[0],
ctx.files.point_type_traits_hdr_tpl_path[0],
ctx.files.msm_hdr_tpl_path[0],
ctx.files.msm_src_tpl_path[0],
ctx.files.msm_gpu_hdr_tpl_path[0],
Expand Down Expand Up @@ -81,9 +84,9 @@ generate_ec_point = rule(
allow_single_file = True,
default = Label("@kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:prime_field.cc.tpl"),
),
"prime_field_traits_hdr_tpl_path": attr.label(
"prime_field_type_traits_hdr_tpl_path": attr.label(
allow_single_file = True,
default = Label("@kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:prime_field_traits.h.tpl"),
default = Label("@kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:prime_field_type_traits.h.tpl"),
),
"ext_field_hdr_tpl_path": attr.label(
allow_single_file = True,
Expand All @@ -93,9 +96,9 @@ generate_ec_point = rule(
allow_single_file = True,
default = Label("@kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:ext_field.cc.tpl"),
),
"ext_field_traits_hdr_tpl_path": attr.label(
"ext_field_type_traits_hdr_tpl_path": attr.label(
allow_single_file = True,
default = Label("@kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:ext_field_traits.h.tpl"),
default = Label("@kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:ext_field_type_traits.h.tpl"),
),
"point_hdr_tpl_path": attr.label(
allow_single_file = True,
Expand All @@ -109,6 +112,10 @@ generate_ec_point = rule(
allow_single_file = True,
default = Label("@kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:point_traits.h.tpl"),
),
"point_type_traits_hdr_tpl_path": attr.label(
allow_single_file = True,
default = Label("@kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:point_type_traits.h.tpl"),
),
"msm_hdr_tpl_path": attr.label(
allow_single_file = True,
default = Label("@kroma_network_tachyon//tachyon/c/math/elliptic_curves/generator:msm.h.tpl"),
Expand Down Expand Up @@ -158,17 +165,19 @@ def generate_ec_points(
("gen_fq6_src", "fq6.cc", 6, 2),
("gen_fq12_hdr", "fq12.h", 12, 6),
("gen_fq12_src", "fq12.cc", 12, 6),
("gen_fq2_traits_hdr", "fq2_traits.h", 2, 1),
("gen_fq6_traits_hdr", "fq6_traits.h", 6, 2),
("gen_fq12_traits_hdr", "fq12_traits.h", 12, 6),
("gen_fq2_type_traits_hdr", "fq2_type_traits.h", 2, 1),
("gen_fq6_type_traits_hdr", "fq6_type_traits.h", 6, 2),
("gen_fq12_type_traits_hdr", "fq12_type_traits.h", 12, 6),
("gen_g1_hdr", "g1.h", 0, 0),
("gen_g1_src", "g1.cc", 0, 0),
("gen_g2_hdr", "g2.h", 0, 0),
("gen_g2_src", "g2.cc", 0, 0),
("gen_fq_traits", "fq_traits.h", 0, 0),
("gen_fr_traits", "fr_traits.h", 0, 0),
("gen_fq_type_traits", "fq_type_traits.h", 0, 0),
("gen_fr_type_traits", "fr_type_traits.h", 0, 0),
("gen_g1_point_traits", "g1_point_traits.h", 0, 0),
("gen_g2_point_traits", "g2_point_traits.h", 0, 0),
("gen_g1_point_type_traits", "g1_point_type_traits.h", 0, 0),
("gen_g2_point_type_traits", "g2_point_type_traits.h", 0, 0),
("gen_msm_hdr", "msm.h", 0, 0),
("gen_msm_src", "msm.cc", 0, 0),
("gen_msm_gpu_hdr", "msm_gpu.h", 0, 0),
Expand All @@ -189,7 +198,7 @@ def generate_ec_points(
name = "fq",
hdrs = [
"fq.h",
"fq_traits.h",
"fq_type_traits.h",
],
srcs = ["fq.cc"],
deps = g1_deps + [
Expand All @@ -202,7 +211,7 @@ def generate_ec_points(
name = "fr",
hdrs = [
"fr.h",
"fr_traits.h",
"fr_type_traits.h",
],
srcs = ["fr.cc"],
deps = g1_deps + [
Expand All @@ -215,7 +224,7 @@ def generate_ec_points(
name = "fq2",
hdrs = [
"fq2.h",
"fq2_traits.h",
"fq2_type_traits.h",
],
srcs = ["fq2.cc"],
deps = fq2_deps + [":fq"],
Expand All @@ -225,7 +234,7 @@ def generate_ec_points(
name = "fq6",
hdrs = [
"fq6.h",
"fq6_traits.h",
"fq6_type_traits.h",
],
srcs = ["fq6.cc"],
deps = fq6_deps + [":fq2"],
Expand All @@ -235,7 +244,7 @@ def generate_ec_points(
name = "fq12",
hdrs = [
"fq12.h",
"fq12_traits.h",
"fq12_type_traits.h",
],
srcs = ["fq12.cc"],
deps = fq12_deps + [":fq6"],
Expand All @@ -246,12 +255,13 @@ def generate_ec_points(
hdrs = [
"g1.h",
"g1_point_traits.h",
"g1_point_type_traits.h",
],
srcs = ["g1.cc"],
deps = [
":fq",
":fr",
"//tachyon/c/math/elliptic_curves:point_conversions",
"//tachyon/c/math/elliptic_curves:point_traits_forward",
],
)

Expand All @@ -260,12 +270,13 @@ def generate_ec_points(
hdrs = [
"g2.h",
"g2_point_traits.h",
"g2_point_type_traits.h",
],
srcs = ["g2.cc"],
deps = g2_deps + [
":fq2",
":fr",
"//tachyon/c/math/elliptic_curves:point_conversions",
"//tachyon/c/math/elliptic_curves:point_traits_forward",
],
)

Expand Down
2 changes: 1 addition & 1 deletion tachyon/c/math/elliptic_curves/generator/ext_field.cc.tpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// clang-format off
#include "tachyon/c/base/type_traits_forward.h"
#include "tachyon/c/math/elliptic_curves/%{header_dir_name}/fq%{degree}_traits.h"
#include "tachyon/c/math/elliptic_curves/%{header_dir_name}/fq%{degree}_type_traits.h"

tachyon_%{type}_fq%{degree} tachyon_%{type}_fq%{degree}_zero() {
using namespace tachyon::c::base;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ struct TypeTraits<tachyon_%{type}_fq%{degree}> {
using NativeType = tachyon::math::%{type}::Fq%{degree};
};

} // namespace tachyon::cc::math
} // namespace tachyon::c::base
// clang-format on
Loading

0 comments on commit 2e2c671

Please sign in to comment.