Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

U8Air Extended #115

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pil2-components/lib/std/pil/std_prod.pil
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,17 @@ private function update_piop_prod(int name, int proves, int opid, expr sel, expr
}

// define constraints at the air level
on final air piop_gprod_air();
on final(-1) air piop_gprod_air();

// update values at the airgroup level
on final airgroup piop_gprod_airgroup();
on final(-1) airgroup piop_gprod_airgroup();

// at the end, check consistency of all the opids
on final proof check_opids_were_completed_prod();
on final(-1) proof check_opids_were_completed_prod();
}

// update constraints at the proof level
on final proof piop_gprod_proof();
on final(-1) proof piop_gprod_proof();
}

/**
Expand Down
196 changes: 155 additions & 41 deletions pil2-components/lib/std/pil/std_range_check.pil
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ require "std_lookup.pil";
// Moreover, having a fewer ranges makes the preprocessing faster and the prvoing less memory consuming

const int MAX_RANGE_LEN = (PRIME - 1) / 2;
const int BYTE = 2**8-1;
const int TWOBYTES = 2**16-1;
const int STD_P2_8 = 2**8;
const int STD_BYTE = STD_P2_8-1;
const int STD_TWO_BYTES = 2**16-1;

const int OPIDS[2] = [100, 101];
int last_assigned_opid = OPIDS[length(OPIDS) - 1];
const int DEFAULT_OPIDS[3] = [100, 101, 102];
int last_assigned_opid = DEFAULT_OPIDS[length(DEFAULT_OPIDS) - 1];

private function next_available_opid(): int {
last_assigned_opid++;
Expand All @@ -18,10 +19,10 @@ private function next_available_opid(): int {

private function get_opid(const int min, const int max, const int predefined): int {
if (predefined && min >= 0) {
if (max <= BYTE) {
return OPIDS[0];
} else if (max <= TWOBYTES) {
return OPIDS[1];
if (max <= STD_BYTE) {
return (min > 0 || max < STD_BYTE) ? DEFAULT_OPIDS[2] : DEFAULT_OPIDS[0];
} else if (max <= STD_TWO_BYTES) {
return DEFAULT_OPIDS[1];
}
}

Expand All @@ -41,8 +42,28 @@ airtemplate U8Air(const int N = 2**8) {

@u8air{reference: mul};

col fixed U8 = [0..BYTE];
lookup_proves(OPIDS[0], [U8], mul, PIOP_NAME_RANGE_CHECK);
col fixed U8 = [0..STD_BYTE];
lookup_proves(DEFAULT_OPIDS[0], [U8], mul, PIOP_NAME_RANGE_CHECK);
}

airtemplate U8AirExtended(const int N = 2**16) {
// The same as U8Air, but with the cartisian product of U8 with itself to allow for lookups from two to two

if (N != 2**16) {
error(`The number of rows N should be 2**16 to use the predefined range U8xU8, got N=${N} instead`);
}

// save the airgroup id and air id of the table for latter use
proof.std.u8ext.airgroup_id = AIRGROUP_ID;
proof.std.u8ext.air_id = AIR_ID;

col witness mul;

@u8airext{reference: mul};

col fixed U8_A = [0..STD_BYTE]...;
col fixed U8_B = [0:STD_P2_8..STD_BYTE:STD_P2_8];
lookup_proves(DEFAULT_OPIDS[2], [U8_A, U8_B], mul, PIOP_NAME_RANGE_CHECK);
}

airtemplate U16Air(const int N = 2**16) {
Expand All @@ -58,8 +79,8 @@ airtemplate U16Air(const int N = 2**16) {

@u16air{reference: mul};

col fixed U16 = [0..TWOBYTES];
lookup_proves(OPIDS[1], [U16], mul, PIOP_NAME_RANGE_CHECK);
col fixed U16 = [0..STD_TWO_BYTES];
lookup_proves(DEFAULT_OPIDS[1], [U16], mul, PIOP_NAME_RANGE_CHECK);
}

airtemplate SpecifiedRanges(const int N, const int opids[], const int opids_count, const int predefineds[], const int mins[], const int maxs[]) {
Expand Down Expand Up @@ -98,40 +119,58 @@ airtemplate SpecifiedRanges(const int N, const int opids[], const int opids_coun
function range_check(expr colu, int min, int max, expr sel = 1, int predefined = 1) {
range_validator(min, max);

@range_def{predefined: predefined, min: min, max: max, min_neg: min < 0, max_neg: max < 0};

if (min < 0) {
println(`The provided min=${min} is negative. Falling back to specified range...`);
} else if (max > TWOBYTES) {
println(`The provided max=${max} is greater than the maximum predefined ${TWOBYTES}. Falling back to specified range...`);
} else if (max > STD_TWO_BYTES) {
println(`The provided max=${max} is greater than the maximum predefined ${STD_TWO_BYTES}. Falling back to specified range...`);
}

const int opid = opid_process(min, max, predefined);

// Check if the range can be absorbed into the predefined ranges
const int absorb = predefined && min >= 0 && max <= TWOBYTES;
const int absorb = predefined && min >= 0 && max <= STD_TWO_BYTES;

// Define the assume
if (absorb) {
if (min == 0 && (max == BYTE || max == TWOBYTES)) {
const int is_u8 = max == BYTE ? 1 : 0;
if (min == 0) {
if (max == STD_BYTE) {
// Here, we should decide whether to use the U8 or the U8xU8 table, using the latter by default to optimize the number of lookups
container air.std.rc.u8ext alias rc_u8ext {
int rc_u8ext_ncols = 0;
expr rc_u8ext_cols[100];
expr rc_u8ext_sels[100];
}

rc_u8ext.rc_u8ext_cols[rc_u8ext.rc_u8ext_ncols] = colu;
rc_u8ext.rc_u8ext_sels[rc_u8ext.rc_u8ext_ncols] = sel;
rc_u8ext.rc_u8ext_ncols++;

on final air define_assumes_u8ext();
return;
} else if (max == STD_TWO_BYTES) {
@range_def{predefined: predefined, min: min, max: max, min_neg: min < 0, max_neg: max < 0, type: "U16"};

airgroup.std.rc.u8_used = airgroup.std.rc.u8_used || is_u8;
airgroup.std.rc.u16_used = airgroup.std.rc.u16_used || 1-is_u8;
lookup_assumes(opid, [colu], sel, PIOP_NAME_RANGE_CHECK);
airgroup.std.rc.u16_used = 1;

lookup_assumes(opid, [colu], sel, PIOP_NAME_RANGE_CHECK);
}
} else {
// Here, we need to reuse to some of the default ranges depending
// on the values of min and max
if (max <= BYTE) {
// reuse U8
airgroup.std.rc.u8_used = 1;
if (max <= STD_BYTE) {
@range_def{predefined: predefined, min: min, max: max, min_neg: min < 0, max_neg: max < 0, type: "U8ExtDouble"};

// first prove that colu - min is in U8
lookup_assumes(opid, [colu - min], sel, PIOP_NAME_RANGE_CHECK);
// Here, the range is of the form [a,b], with a >= 0 and b <= STD_BYTE, except for the range [0,STD_BYTE]

// To avoid two lookups, we use the U8xU8 air
airgroup.std.rc.u8ext_used = 1;

// colu is in [a,b] iff colu-min is in [0,STD_BYTE] and max-colu is in [0,STD_BYTE]
// iff [colu-min,max-colu] is in [0,STD_BYTE]x[0,STD_BYTE]
lookup_assumes(opid, [colu - min, max - colu], sel, PIOP_NAME_RANGE_CHECK);
} else if (max <= STD_TWO_BYTES) {
@range_def{predefined: predefined, min: min, max: max, min_neg: min < 0, max_neg: max < 0, type: "U16Double"};

// Here, the range is of the form [a,b], with a >= 0 and STD_BYTE < b <= STD_TWO_BYTES, except for the range [0,STD_TWO_BYTES]

// then prove that max - colu is in U8
lookup_assumes(opid, [max - colu], sel, PIOP_NAME_RANGE_CHECK);
} else if (max <= TWOBYTES) {
// reuse U16
airgroup.std.rc.u16_used = 1;

Expand All @@ -143,6 +182,8 @@ function range_check(expr colu, int min, int max, expr sel = 1, int predefined =
}
}
} else {
@range_def{predefined: predefined, min: min, max: max, min_neg: min < 0, max_neg: max < 0, type: "Specified"};

lookup_assumes(opid, [colu], sel, PIOP_NAME_RANGE_CHECK);
}

Expand Down Expand Up @@ -183,8 +224,8 @@ function multi_range_check(expr colu, int min1, int max1, int min2, int max2, ex
range_validator(min1, max1);
range_validator(min2, max2);

@range_def{predefined: predefined, min: min1, max: max1, min_neg: min1 < 0 , max_neg: max1 < 0};
@range_def{predefined: predefined, min: min2, max: max2, min_neg: min2 < 0 , max_neg: max2 < 0};
@range_def{predefined: predefined, min: min1, max: max1, min_neg: min1 < 0 , max_neg: max1 < 0, type: "Specified"};
@range_def{predefined: predefined, min: min2, max: max2, min_neg: min2 < 0 , max_neg: max2 < 0, type: "Specified"};

const int opid1 = opid_process(min1, max1, predefined);
const int opid2 = opid_process(min2, max2, predefined);
Expand Down Expand Up @@ -220,7 +261,7 @@ function range_check_id(const int min, const int max, const int predefined = 0):

range_validator(min, max);

@range_def{predefined: predefined, min: min, max: max, min_neg: min < 0, max_neg: max < 0};
@range_def{predefined: predefined, min: min, max: max, min_neg: min < 0, max_neg: max < 0, type: "Specified"};

container air.std.rcid alias rcid {
int opids_count_id = 0;
Expand Down Expand Up @@ -291,13 +332,15 @@ private function range_validator(const int min, const int max) {

private function opid_process(const int min, const int max, const int predefined): int {
container proof.std.rc alias rcproof {
// Number of times U8, U16 and specified ranges air are used
// Number of times defined range airs are used
int num_u8_airgroup = 0;
int num_u8ext_airgroup = 0;
int num_u16_airgroup = 0;
int num_spec_airgroup = 0;

// Last airgroup id that uses U8, U16 and specified ranges air
// Last airgroup id that defined range airs used
int max_u8_airgroup_id = 0;
int max_u8ext_airgroup_id = 0;
int max_u16_airgroup_id = 0;
int max_spec_airgroup_id = 0;

Expand All @@ -315,6 +358,7 @@ private function opid_process(const int min, const int max, const int predefined
container airgroup.std.rc {
// To mark if the U8 and U16 airs are used within the airgroup
int u8_used = 0;
int u8ext_used = 0;
int u16_used = 0;
}

Expand All @@ -329,7 +373,7 @@ private function opid_process(const int min, const int max, const int predefined
const int opid = get_opid(min, max, predefined);

// Exit if the range does not belong to the specified ranges air
if (opid <= OPIDS[1]) {
if (opid <= DEFAULT_OPIDS[length(DEFAULT_OPIDS) - 1]) {
return opid;
}

Expand All @@ -341,13 +385,46 @@ private function opid_process(const int min, const int max, const int predefined
rcproof.opids_count++;

// if the opid is not predefined and the range is bigger than the current specified_N, we update it
if (opid > OPIDS[1] && max - min > rcproof.specified_N) {
if (opid > DEFAULT_OPIDS[length(DEFAULT_OPIDS) - 1] && max - min > rcproof.specified_N) {
rcproof.specified_N = max - min;
}

return opid;
}

private function define_assumes_u8ext() {
use air.std.rc.u8ext;

@range_def{predefined: 1, min: 0, max: STD_BYTE, min_neg: 0, max_neg: 0, type: rc_u8ext_ncols == 1 ? "U8" : "U8Ext"};

if (rc_u8ext_ncols == 1) {
// If there is only one column, we use the U8 air
airgroup.std.rc.u8_used = 1;

lookup_assumes(DEFAULT_OPIDS[0], [rc_u8ext_cols[0]], rc_u8ext_sels[0], PIOP_NAME_RANGE_CHECK);
} else if (rc_u8ext_ncols > 1) {
// If there are more than one column, we use the U8xU8 air
airgroup.std.rc.u8ext_used = 1;

for (int i = 0; i < rc_u8ext_ncols/2; i++) {
// Hint how the columns are grouped
@u8airext_cols{first_column: rc_u8ext_cols[2*i], second_column: rc_u8ext_cols[2*i+1]};

lookup_assumes(DEFAULT_OPIDS[2], [rc_u8ext_cols[2*i], rc_u8ext_cols[2*i+1]], rc_u8ext_sels[2*i]*rc_u8ext_sels[2*i+1], PIOP_NAME_RANGE_CHECK);
}

// Reuse the U8xU8 air for the last column if the number of columns is odd
if (rc_u8ext_ncols % 2 == 1) {
// Hint how the columns are grouped
@u8airext_cols{first_column: rc_u8ext_cols[rc_u8ext_ncols-1], second_column: 0};

lookup_assumes(DEFAULT_OPIDS[2], [rc_u8ext_cols[rc_u8ext_ncols-1], 0], rc_u8ext_sels[rc_u8ext_ncols-1], PIOP_NAME_RANGE_CHECK);
}
}

define_proves(absorb: 1);
}

private function define_proves(const int absorb) {
use proof.std.rc;

Expand All @@ -359,6 +436,13 @@ private function define_proves(const int absorb) {
}
}

// If the U8Ext was used, update the max airgroup id
if (airgroup.std.rc.u8ext_used) {
if (max_u8ext_airgroup_id < AIRGROUP_ID) {
max_u8ext_airgroup_id = AIRGROUP_ID;
}
}

// If the U16 was used, update the max airgroup id
if (airgroup.std.rc.u16_used) {
if (proof.std.rc.max_u16_airgroup_id < AIRGROUP_ID) {
Expand Down Expand Up @@ -387,13 +471,22 @@ private function declarePreRangeAir() {
num_u8_airgroup++;
}

// If the U8Ext was used in the airgroup, update the number of U8Ext airgroups
if (airgroup.std.rc.u8ext_used) {
num_u8ext_airgroup++;
}

// If the U16 was used in the airgroup, update the number of U16 airgroups
if (airgroup.std.rc.u16_used) {
num_u16_airgroup++;
}

// The U8 and U16 airs are only needed once, so we wait for the last airgroup that uses them
if (AIRGROUP_ID != max_u8_airgroup_id && AIRGROUP_ID != max_u16_airgroup_id) {
// The U8, U8ext and U16 airs are only needed once, so we wait for the last airgroup that uses them
if (
AIRGROUP_ID != max_u8_airgroup_id &&
AIRGROUP_ID != max_u8ext_airgroup_id &&
AIRGROUP_ID != max_u16_airgroup_id
) {
return;
}

Expand All @@ -415,6 +508,23 @@ private function declarePreRangeAir() {
}
}

if (AIRGROUP_ID == max_u8ext_airgroup_id && num_u8ext_airgroup > 0) {
container proof.std.u8ext {
int airgroup_id = 0;
int air_id = 0;
}

if (num_u8ext_airgroup == 1){
// If the U8AirExtended is needed only once, we instantiate it in the (single) callable airgroup
U8AirExtended();
} else {
// If the U8AirExtended is needed more than once, we instantiate it in its own airgroup
airgroup U8AirExtended {
U8AirExtended();
}
}
}

if (AIRGROUP_ID == max_u16_airgroup_id && num_u16_airgroup > 0) {
container proof.std.u16 {
int airgroup_id = 0;
Expand Down Expand Up @@ -467,6 +577,10 @@ private function createPreMetadata() {
@u8air{airgroup_id: proof.std.u8.airgroup_id, air_id: proof.std.u8.air_id};
}

if (proof.std.rc.num_u8ext_airgroup > 0) {
@u8airext{airgroup_id: proof.std.u8ext.airgroup_id, air_id: proof.std.u8ext.air_id};
}

if (proof.std.rc.num_u16_airgroup > 0) {
@u16air{airgroup_id: proof.std.u16.airgroup_id, air_id: proof.std.u16.air_id};
}
Expand Down
8 changes: 4 additions & 4 deletions pil2-components/lib/std/pil/std_sum.pil
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,17 @@ private function update_piop_sum(int name, int proves, int opid[], expr sumid, e
// on final air find_repeated_proves();

// define constraints at the air level
on final air piop_gsum_air();
on final(-1) air piop_gsum_air();

// update the contributions at the airgroup level
on final airgroup piop_gsum_airgroup();
on final(-1) airgroup piop_gsum_airgroup();

// at the end, check consistency of all the opids
on final proof check_opids_were_completed_sum();
on final(-1) proof check_opids_were_completed_sum();
}

// adds the global constraint
on final proof piop_gsum_proof();
on final(-1) proof piop_gsum_proof();
}

/**
Expand Down
2 changes: 2 additions & 0 deletions pil2-components/lib/std/rs/src/range_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ mod specified_ranges;
mod std_range_check;
mod u16air;
mod u8air;
mod u8air_extended;

pub use range::*;
pub use specified_ranges::*;
pub use std_range_check::*;
pub use u16air::*;
pub use u8air::*;
pub use u8air_extended::*;
Loading
Loading