From 43d76f7a8846cd4f4f5e0f8e9146ab7d6298b72c Mon Sep 17 00:00:00 2001 From: Samarth Kulshreshtha Date: Thu, 6 Sep 2018 13:54:48 -0500 Subject: [PATCH] Add tests and refactor polynomial.py --- honeybadgermpc/polynomial.py | 132 ++++++++++------------------------- honeybadgermpc/rand_batch.py | 4 +- tests/test_polynomial.py | 56 +++++++++++++++ 3 files changed, 94 insertions(+), 98 deletions(-) create mode 100644 tests/test_polynomial.py diff --git a/honeybadgermpc/polynomial.py b/honeybadgermpc/polynomial.py index f0cdc5e5..52b3690a 100644 --- a/honeybadgermpc/polynomial.py +++ b/honeybadgermpc/polynomial.py @@ -1,8 +1,6 @@ import operator import random from functools import reduce -import sys -import time def strip_trailing_zeros(a): @@ -73,7 +71,7 @@ def evaluate_fft(self, omega, n): assert type(omega) is field assert omega ** n == 1, "must be an n'th root of unity" assert omega ** (n//2) != 1, "must be a primitive n'th root of unity" - return fft(self, n, omega) + return fft(self, omega, n) @classmethod def random(cls, degree, y0=None): @@ -82,6 +80,25 @@ def random(cls, degree, y0=None): coeffs[0] = y0 return cls(coeffs) + @classmethod + def interp_extrap(cls, xs, omega): + """ + Interpolates the polynomial based on the even points omega^2i + then evaluates at all points omega^i + """ + n = len(xs) + assert n & (n-1) == 0, "n must be power of 2" + assert pow(omega, 2*n) == 1, "omega must be 2n'th root of unity" + assert pow(omega, n) != 1, "omega must be primitive 2n'th root of unity" + + # Interpolate the polynomial up to degree n + poly = cls.interpolate_fft(xs, omega**2) + + # Evaluate the polynomial + xs2 = poly.evaluate_fft(omega, 2*n) + + return xs2 + _poly_cache[field] = Polynomial return Polynomial @@ -94,17 +111,15 @@ def get_omega(field, n, seed=None): This only makes sense if n is a power of 2! """ - p = field.modulus + assert n & n-1 == 0, "n must be a power of 2" if seed is not None: random.seed(seed) - x = field(random.randint(0, p-1)) - y = pow(x, (p-1)//n) - if y == 1: + x = field(random.randint(0, field.modulus-1)) + y = pow(x, (field.modulus-1)//n) + if y == 1 or pow(y, n//2) == 1: return get_omega(field, n) - if pow(y, n//2) == 1: - return get_omega(field, n) - assert pow(y, n) == 1 - assert pow(y, n//2) != 1 + assert pow(y, n) == 1, "omega must be 2n'th root of unity" + assert pow(y, n//2) != 1, "omega must be primitive 2n'th root of unity" return y @@ -117,15 +132,12 @@ def fft_helper(A, omega, field): list is of the form [a0, a1, ... , an]. """ n = len(A) - - # Check if n is a power of 2. - assert not (n & (n-1)) + assert not (n & (n-1)), "n must be a power of 2" if n == 1: return A - B = A[0::2] - C = A[1::2] + B, C = A[0::2], A[1::2] B_bar = fft_helper(B, pow(omega, 2), field) C_bar = fft_helper(C, pow(omega, 2), field) A_bar = [field(1)]*(n) @@ -135,86 +147,14 @@ def fft_helper(A, omega, field): return A_bar -def nearest_power_of_two(x): - return 2**x.bit_length() - - -def fft(poly, n=None, omega=None, seed=None, test=False, enable_profiling=False): - - coefficients = poly.coeffs - Zp = poly.field - n = nearest_power_of_two(max(len(coefficients), n)-1) - - padded_coefficients = coefficients + ([0] * (n-len(coefficients))) - - if omega is None: - s = time.time() - omega = get_omega(Zp, n, seed) - e = time.time() - - if enable_profiling: - print("OMEGA", omega, "TIME:", e - s) - if test: - assert pow(omega, n) == 1 - # assert pow(omega,n//2) != 1 - - s = time.time() - output = fft_helper(padded_coefficients, omega, Zp) - e = time.time() - - if enable_profiling: - print("FFT completed", "TIME:", e - s) - - if test: - test_correctness(poly, omega, output) - - return output - - -def test_correctness(poly, omega, fft_helper_result): - """ - This method verifies the output generated by FFT by evaluating - the polynomial at the coefficients. - - In order to save time, it verifies only 100 points from the FFT output - after evaluating them at the coefficients. - """ - assert len(poly.coeffs) <= len(fft_helper_result) - n = len(fft_helper_result) - - # Test on 100 random points - total_verification_points = 100 - sample = [random.randint(0, n-1) for i in range(total_verification_points)] - - c = 0 - for i in sample: - y = poly(omega**i) - c += 1 - sys.stdout.write("%d / %d points verified!" % (c, - total_verification_points)) - char = "\r" if c < len(sample) else "\n" - sys.stdout.write(char) - sys.stdout.flush() - assert y == fft_helper_result[i] - - -def interp_extrap(Poly, xs, omega): - """ - Interpolates the polynomial based on the even points omega^2i - then evaluates at all points omega^i - """ - n = len(xs) - assert n & (n-1) == 0, "n must be power of 2" - assert pow(omega, 2*n) == 1, "omega must be 2n'th root of unity" - assert pow(omega, n) != 1, "omega must be primitive 2n'th root of unity" - - # Interpolate the polynomial up to degree n - poly = Poly.interpolate_fft(xs, omega**2) - - # Evaluate the polynomial - xs2 = poly.evaluate_fft(omega, 2*n) +def fft(poly, omega, n, seed=None): + assert n & n-1 == 0, "n must be a power of 2" + assert len(poly.coeffs) <= n + assert pow(omega, n) == 1 + assert pow(omega, n//2) != 1 - return xs2 + paddedCoeffs = poly.coeffs + ([poly.field(0)] * (n-len(poly.coeffs))) + return fft_helper(paddedCoeffs, omega, poly.field) if __name__ == "__main__": @@ -242,7 +182,7 @@ def interp_extrap(Poly, xs, omega): for i in range(len(x)): print(omega**(2*i), x[i]) print('interp_extrap:') - x3 = interp_extrap(Poly, x, omega) + x3 = Poly.interp_extrap(x, omega) for i in range(len(x3)): print(omega**i, x3[i]) diff --git a/honeybadgermpc/rand_batch.py b/honeybadgermpc/rand_batch.py index 5b7bdaf2..9a68fb3c 100644 --- a/honeybadgermpc/rand_batch.py +++ b/honeybadgermpc/rand_batch.py @@ -1,7 +1,7 @@ import asyncio import random from .field import GF -from .polynomial import polynomialsOver, interp_extrap, get_omega +from .polynomial import polynomialsOver, get_omega # Fix the field for now Field = GF(0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001) @@ -97,7 +97,7 @@ def nearest_power_of_two(x): return 2**(x-1).bit_length() # Round up # Interpolate all the committed shares omega = get_omega(Field, 2*D, seed=0) - outputs = interp_extrap(Poly, input_shares[:D], omega) + outputs = Poly.interp_extrap(input_shares[:D], omega) output_shares = outputs[1:2*((sum(valid)-f)*B):2] # Pick the odd shares print('output_shares:', len(output_shares)) diff --git a/tests/test_polynomial.py b/tests/test_polynomial.py new file mode 100644 index 00000000..980c0508 --- /dev/null +++ b/tests/test_polynomial.py @@ -0,0 +1,56 @@ +from random import randint +from honeybadgermpc.polynomial import get_omega + + +def test_poly_eval_at_k(GaloisField, Polynomial): + poly1 = Polynomial([0, 1]) # y = x + for i in range(10): + assert poly1(i) == i + + poly2 = Polynomial([10, 0, 1]) # y = x^2 + 10 + for i in range(10): + assert poly2(i) == pow(i, 2) + 10 + + d = randint(1, 50) + coeffs = [randint(0, GaloisField.modulus-1) for i in range(d)] + poly3 = Polynomial(coeffs) # random polynomial of degree d + x = randint(0, GaloisField.modulus-1) + y = sum([pow(x, i) * a for i, a in enumerate(coeffs)]) + assert y == poly3(x) + + +def test_evaluate_fft(GaloisField, Polynomial): + d = randint(210, 300) + coeffs = [randint(0, GaloisField.modulus-1) for i in range(d)] + poly = Polynomial(coeffs) # random polynomial of degree d + n = len(poly.coeffs) + n = n if n & n-1 == 0 else 2**n.bit_length() + omega = get_omega(GaloisField, n) + fftResult = poly.evaluate_fft(omega, n) + assert len(fftResult) == n + for i, a in zip(range(1, 201, 2), fftResult[1:201:2]): # verify only 100 points + assert poly(pow(omega, i)) == a + + +def test_interpolate_fft(GaloisField, Polynomial): + d = randint(210, 300) + y = [randint(0, GaloisField.modulus-1) for i in range(d)] + n = len(y) + n = n if n & n-1 == 0 else 2**n.bit_length() + ys = y + [GaloisField(0)] * (n - len(y)) + omega = get_omega(GaloisField, n) + poly = Polynomial.interpolate_fft(ys, omega) + for i, a in zip(range(1, 201, 2), ys[1:201:2]): # verify only 100 points + assert poly(pow(omega, i)) == a + + +def test_interp_extrap(GaloisField, Polynomial): + d = randint(210, 300) + y = [randint(0, GaloisField.modulus-1) for i in range(d)] + n = len(y) + n = n if n & n-1 == 0 else 2**n.bit_length() + ys = y + [GaloisField(0)] * (n - len(y)) + omega = get_omega(GaloisField, 2*n) + values = Polynomial.interp_extrap(ys, omega) + for a, b in zip(ys, values[0:201:2]): # verify only 100 points + assert a == b