Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests and refactor polynomial.py #51

Merged
merged 1 commit into from
Sep 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 36 additions & 96 deletions honeybadgermpc/polynomial.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import operator
import random
from functools import reduce
import sys
import time


def strip_trailing_zeros(a):
Expand Down Expand Up @@ -73,7 +71,7 @@ def evaluate_fft(self, omega, n):
assert type(omega) is field
assert omega ** n == 1, "must be an n'th root of unity"
assert omega ** (n//2) != 1, "must be a primitive n'th root of unity"
return fft(self, n, omega)
return fft(self, omega, n)

@classmethod
def random(cls, degree, y0=None):
Expand All @@ -82,6 +80,25 @@ def random(cls, degree, y0=None):
coeffs[0] = y0
return cls(coeffs)

@classmethod
def interp_extrap(cls, xs, omega):
"""
Interpolates the polynomial based on the even points omega^2i
then evaluates at all points omega^i
"""
n = len(xs)
assert n & (n-1) == 0, "n must be power of 2"
assert pow(omega, 2*n) == 1, "omega must be 2n'th root of unity"
assert pow(omega, n) != 1, "omega must be primitive 2n'th root of unity"
smkuls marked this conversation as resolved.
Show resolved Hide resolved

# Interpolate the polynomial up to degree n
poly = cls.interpolate_fft(xs, omega**2)

# Evaluate the polynomial
xs2 = poly.evaluate_fft(omega, 2*n)

return xs2

_poly_cache[field] = Polynomial
return Polynomial

Expand All @@ -94,17 +111,15 @@ def get_omega(field, n, seed=None):

This only makes sense if n is a power of 2!
"""
p = field.modulus
assert n & n-1 == 0, "n must be a power of 2"
smkuls marked this conversation as resolved.
Show resolved Hide resolved
if seed is not None:
random.seed(seed)
x = field(random.randint(0, p-1))
y = pow(x, (p-1)//n)
if y == 1:
x = field(random.randint(0, field.modulus-1))
y = pow(x, (field.modulus-1)//n)
if y == 1 or pow(y, n//2) == 1:
return get_omega(field, n)
if pow(y, n//2) == 1:
return get_omega(field, n)
assert pow(y, n) == 1
assert pow(y, n//2) != 1
assert pow(y, n) == 1, "omega must be 2n'th root of unity"
assert pow(y, n//2) != 1, "omega must be primitive 2n'th root of unity"
smkuls marked this conversation as resolved.
Show resolved Hide resolved
return y


Expand All @@ -117,15 +132,12 @@ def fft_helper(A, omega, field):
list is of the form [a0, a1, ... , an].
"""
n = len(A)

# Check if n is a power of 2.
assert not (n & (n-1))
assert not (n & (n-1)), "n must be a power of 2"
smkuls marked this conversation as resolved.
Show resolved Hide resolved

if n == 1:
return A

B = A[0::2]
C = A[1::2]
B, C = A[0::2], A[1::2]
B_bar = fft_helper(B, pow(omega, 2), field)
C_bar = fft_helper(C, pow(omega, 2), field)
A_bar = [field(1)]*(n)
Expand All @@ -135,86 +147,14 @@ def fft_helper(A, omega, field):
return A_bar


def nearest_power_of_two(x):
return 2**x.bit_length()


def fft(poly, n=None, omega=None, seed=None, test=False, enable_profiling=False):

coefficients = poly.coeffs
Zp = poly.field
n = nearest_power_of_two(max(len(coefficients), n)-1)

padded_coefficients = coefficients + ([0] * (n-len(coefficients)))

if omega is None:
s = time.time()
omega = get_omega(Zp, n, seed)
e = time.time()

if enable_profiling:
print("OMEGA", omega, "TIME:", e - s)
if test:
assert pow(omega, n) == 1
# assert pow(omega,n//2) != 1

s = time.time()
output = fft_helper(padded_coefficients, omega, Zp)
e = time.time()

if enable_profiling:
print("FFT completed", "TIME:", e - s)

if test:
test_correctness(poly, omega, output)

return output


def test_correctness(poly, omega, fft_helper_result):
"""
This method verifies the output generated by FFT by evaluating
the polynomial at the coefficients.

In order to save time, it verifies only 100 points from the FFT output
after evaluating them at the coefficients.
"""
assert len(poly.coeffs) <= len(fft_helper_result)
n = len(fft_helper_result)

# Test on 100 random points
total_verification_points = 100
sample = [random.randint(0, n-1) for i in range(total_verification_points)]

c = 0
for i in sample:
y = poly(omega**i)
c += 1
sys.stdout.write("%d / %d points verified!" % (c,
total_verification_points))
char = "\r" if c < len(sample) else "\n"
sys.stdout.write(char)
sys.stdout.flush()
assert y == fft_helper_result[i]


def interp_extrap(Poly, xs, omega):
"""
Interpolates the polynomial based on the even points omega^2i
then evaluates at all points omega^i
"""
n = len(xs)
assert n & (n-1) == 0, "n must be power of 2"
assert pow(omega, 2*n) == 1, "omega must be 2n'th root of unity"
assert pow(omega, n) != 1, "omega must be primitive 2n'th root of unity"

# Interpolate the polynomial up to degree n
poly = Poly.interpolate_fft(xs, omega**2)

# Evaluate the polynomial
xs2 = poly.evaluate_fft(omega, 2*n)
def fft(poly, omega, n, seed=None):
assert n & n-1 == 0, "n must be a power of 2"
assert len(poly.coeffs) <= n
assert pow(omega, n) == 1
assert pow(omega, n//2) != 1
smkuls marked this conversation as resolved.
Show resolved Hide resolved

return xs2
paddedCoeffs = poly.coeffs + ([poly.field(0)] * (n-len(poly.coeffs)))
return fft_helper(paddedCoeffs, omega, poly.field)


if __name__ == "__main__":
Expand Down Expand Up @@ -242,7 +182,7 @@ def interp_extrap(Poly, xs, omega):
for i in range(len(x)):
print(omega**(2*i), x[i])
print('interp_extrap:')
x3 = interp_extrap(Poly, x, omega)
x3 = Poly.interp_extrap(x, omega)
for i in range(len(x3)):
print(omega**i, x3[i])

Expand Down
4 changes: 2 additions & 2 deletions honeybadgermpc/rand_batch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import random
from .field import GF
from .polynomial import polynomialsOver, interp_extrap, get_omega
from .polynomial import polynomialsOver, get_omega

# Fix the field for now
Field = GF(0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001)
Expand Down Expand Up @@ -97,7 +97,7 @@ def nearest_power_of_two(x): return 2**(x-1).bit_length() # Round up

# Interpolate all the committed shares
omega = get_omega(Field, 2*D, seed=0)
outputs = interp_extrap(Poly, input_shares[:D], omega)
outputs = Poly.interp_extrap(input_shares[:D], omega)
output_shares = outputs[1:2*((sum(valid)-f)*B):2] # Pick the odd shares
print('output_shares:', len(output_shares))

Expand Down
56 changes: 56 additions & 0 deletions tests/test_polynomial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from random import randint
from honeybadgermpc.polynomial import get_omega


def test_poly_eval_at_k(GaloisField, Polynomial):
poly1 = Polynomial([0, 1]) # y = x
for i in range(10):
assert poly1(i) == i

poly2 = Polynomial([10, 0, 1]) # y = x^2 + 10
for i in range(10):
assert poly2(i) == pow(i, 2) + 10

d = randint(1, 50)
coeffs = [randint(0, GaloisField.modulus-1) for i in range(d)]
poly3 = Polynomial(coeffs) # random polynomial of degree d
x = randint(0, GaloisField.modulus-1)
y = sum([pow(x, i) * a for i, a in enumerate(coeffs)])
assert y == poly3(x)


def test_evaluate_fft(GaloisField, Polynomial):
d = randint(210, 300)
coeffs = [randint(0, GaloisField.modulus-1) for i in range(d)]
poly = Polynomial(coeffs) # random polynomial of degree d
n = len(poly.coeffs)
n = n if n & n-1 == 0 else 2**n.bit_length()
omega = get_omega(GaloisField, n)
fftResult = poly.evaluate_fft(omega, n)
assert len(fftResult) == n
for i, a in zip(range(1, 201, 2), fftResult[1:201:2]): # verify only 100 points
assert poly(pow(omega, i)) == a


def test_interpolate_fft(GaloisField, Polynomial):
d = randint(210, 300)
y = [randint(0, GaloisField.modulus-1) for i in range(d)]
n = len(y)
n = n if n & n-1 == 0 else 2**n.bit_length()
ys = y + [GaloisField(0)] * (n - len(y))
omega = get_omega(GaloisField, n)
poly = Polynomial.interpolate_fft(ys, omega)
for i, a in zip(range(1, 201, 2), ys[1:201:2]): # verify only 100 points
assert poly(pow(omega, i)) == a


def test_interp_extrap(GaloisField, Polynomial):
d = randint(210, 300)
y = [randint(0, GaloisField.modulus-1) for i in range(d)]
n = len(y)
n = n if n & n-1 == 0 else 2**n.bit_length()
ys = y + [GaloisField(0)] * (n - len(y))
omega = get_omega(GaloisField, 2*n)
values = Polynomial.interp_extrap(ys, omega)
for a, b in zip(ys, values[0:201:2]): # verify only 100 points
assert a == b