Skip to content

Commit

Permalink
core: add support for multiplication of secp256k1 points by scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
apoelstra committed May 17, 2017
1 parent 6046834 commit 21937db
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 0 deletions.
114 changes: 114 additions & 0 deletions opentimestamps/core/secp256k1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (C) 2017 The OpenTimestamps developers
#
# This file is part of python-opentimestamps.
#
# It is subject to the license terms in the LICENSE file found in the top-level
# directory of this distribution.
#
# No part of python-opentimestamps including this file, may be copied,
# modified, propagated, or distributed except according to the terms contained
# in the LICENSE file.

## What follows is a lot of inefficient but explicit secp256k1 math
class Point(object):
inf = True
x = 0
y = 0

def __init__(self, x=0, y=0):
self.x = x
self.y = y
if x == 0 and y == 0:
self.inf = True
else:
self.inf = False

def __repr__(self):
if self.inf:
return "Point(infinity)"
else:
return "Point(%x, %x)" % (self.x, self.y)

def __eq__(self, other):
if isinstance(other, self.__class__):
return (self.inf == True and other.inf == True) or\
(self.inf == False and other.inf == False and self.x == other.x and self.y == other.y)
else:
return False

def __ne__(self, other):
return not self.__eq__(other)

@staticmethod
def decode(data):
if len(data) != 33 or (data[0] != 2 and data[0] != 3):
raise MsgValueError("Incorrectly formatted public key")

x = int.from_bytes(data[1:], 'big')
if x >= SECP256K1_P:
raise MsgValueError("out of range x coordinate for secp256k1 point")

ysqr = (x ** 3 + 7) % SECP256K1_P
y = psqrt(ysqr)
if pow(y, 2, SECP256K1_P) != ysqr:
raise MsgValueError("invalid x coordinate for secp256k1 point")

if y % 2 == 1 and data[0] == 2:
y = SECP256K1_P - y
if y % 2 == 0 and data[0] == 3:
y = SECP256K1_P - y

return Point(x, y)

def encode(self):
ret = bytearray(self.x.to_bytes(33, 'big'))
assert(ret[0] == 0)
if self.y % 2 == 1:
ret[0] = 3
else:
ret[0] = 2
return ret

def add(self, pt):
if self.inf:
return pt
if pt.inf:
return self

if self.x == pt.x:
if self.y == SECP256K1_P - pt.y:
return Point()
else:
assert(self.y == pt.y)
lam = (3 * self.x ** 2 * pinv(2 * self.y)) % SECP256K1_P
else:
lam = ((pt.y - self.y) * pinv(pt.x - self.x)) % SECP256K1_P

x3 = (lam ** 2 - self.x - pt.x) % SECP256K1_P
y3 = (self.y + lam * (x3 - self.x)) % SECP256K1_P

return Point(x3, SECP256K1_P - y3)

def scalar_mul(self, s):
ret = Point()
add = self
s = s % SECP256K1_N
while s > 0:
if s % 2 == 1:
ret = ret.add(add)
add = add.add(add) # add
s >>= 1
return ret

SECP256K1_P = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
SECP256K1_N = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
SECP256K1_GEN = Point(0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798,
0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8)

def pinv(x):
return pow(x, SECP256K1_P - 2, SECP256K1_P)

def psqrt(x):
# using `>> 2` in place of `/ 4` keeps everything as an int rather than float
return pow(x, (SECP256K1_P + 1) >> 2, SECP256K1_P)

101 changes: 101 additions & 0 deletions opentimestamps/tests/core/test_secp256k1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (C) 2017 The OpenTimestamps developers
#
# This file is part of python-opentimestamps.
#
# It is subject to the license terms in the LICENSE file found in the top-level
# directory of this distribution.
#
# No part of python-opentimestamps including this file, may be copied,
# modified, propagated, or distributed except according to the terms contained
# in the LICENSE file.

import binascii
import unittest

from opentimestamps.core.secp256k1 import *

class Test_Secp256k1(unittest.TestCase):
def test_point_rt(self):
"""Point encoding round trip"""
gen = SECP256K1_GEN
encode = gen.encode()
self.assertEqual(encode, binascii.unhexlify("0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"))
gen2 = Point().decode(encode)
self.assertEqual(gen, gen2)

def test_pinv(self):
"""Field inversion mod p"""
self.assertEqual(pinv(1), 1)
self.assertEqual(pinv(2), 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe18)
self.assertEqual(pinv(3), 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa9fffffd75)
self.assertEqual(2, pinv(0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe18))
self.assertEqual(3, pinv(0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa9fffffd75))

def test_psqrt(self):
"""Field square root mod p"""
self.assertEqual(psqrt(1), 1)
self.assertEqual(psqrt(2), 0x210c790573632359b1edb4302c117d8a132654692c3feeb7de3a86ac3f3b53f7)
self.assertEqual(psqrt(4), 2)
# may return the sqrt or its negative
self.assertEqual(psqrt(9), SECP256K1_P - 3)
self.assertEqual(psqrt(49), SECP256K1_P - 7)

def test_point_add(self):
"""Point adding and doubling"""

inf = Point()
# P random chosen by dice roll
p1 = Point(0x394867ad93f5c9612e8d8b7600443334026e648e365337d799190e845d649e67,
0x0b84af9a00c1a55a7ac03917e59b21c68d1ffdf18720c3ad279077049cfaaf63)
# 2P
p2 = Point(0x8e6575f6c759aea04a8ec65f61f71eba237a0af54292d41e3a4bac2efa922dea,
0x2b3c07687787ff07ae312305f30481c451ae3b78d4f479a3b729615fedc040e4)
# -2P
np2 = Point(0x8e6575f6c759aea04a8ec65f61f71eba237a0af54292d41e3a4bac2efa922dea,
0xd4c3f897887800f851cedcfa0cfb7e3bae51c4872b0b865c48d69e9f123fbb4b)
# 3P
p3 = Point(0x53dd5e495c7404790f9347470cc9c38ee239809c758f02ec04ba641ab3d0e043,
0xd7a4f5e5bdf21000b1fe7216adbea92cb9917d8fea7b37628c1eddb409a5cd3f)

self.assertEqual(inf.add(inf), inf)
self.assertEqual(p1.add(inf), p1)
self.assertEqual(inf.add(p1), p1)
self.assertEqual(p1.add(p1), p2)
self.assertEqual(p1.add(p2), p3)
self.assertEqual(p2.add(p1), p3)
self.assertEqual(p3.add(np2), p1)
self.assertEqual(np2.add(p3), p1)
self.assertEqual(p2.add(np2), inf)
self.assertEqual(np2.add(p2), inf)

def test_scalar_mul(self):
inf = Point()
# P random chosen by dice roll
p1 = Point(0x394867ad93f5c9612e8d8b7600443334026e648e365337d799190e845d649e67,
0x0b84af9a00c1a55a7ac03917e59b21c68d1ffdf18720c3ad279077049cfaaf63)
# 2P
p2 = Point(0x8e6575f6c759aea04a8ec65f61f71eba237a0af54292d41e3a4bac2efa922dea,
0x2b3c07687787ff07ae312305f30481c451ae3b78d4f479a3b729615fedc040e4)
# -2P
np2 = Point(0x8e6575f6c759aea04a8ec65f61f71eba237a0af54292d41e3a4bac2efa922dea,
0xd4c3f897887800f851cedcfa0cfb7e3bae51c4872b0b865c48d69e9f123fbb4b)
# 3P
p3 = Point(0x53dd5e495c7404790f9347470cc9c38ee239809c758f02ec04ba641ab3d0e043,
0xd7a4f5e5bdf21000b1fe7216adbea92cb9917d8fea7b37628c1eddb409a5cd3f)

# nP
n = 0xa91ce154dcab9adabe08cc1ee84ec3cd0f426bbc08a54a1c41bd25f2587caedd
pn = Point(0x9dc4b057a857ad2ef3535b4a207a7bfc9264e8fcacf718c895db7ead8d445b26,
0x5af110ecb68636e5c352b69fc6348173932b83ca64587a91fd88af1446e33979)

self.assertEqual(inf.scalar_mul(0), inf)
self.assertEqual(inf.scalar_mul(1000), inf)
self.assertEqual(inf.scalar_mul(-1), inf)

self.assertEqual(p1.scalar_mul(0), inf)
self.assertEqual(p1.scalar_mul(1), p1)
self.assertEqual(p1.scalar_mul(2), p2)
self.assertEqual(p1.scalar_mul(-2), np2)
self.assertEqual(p2.scalar_mul(-1), np2)
self.assertEqual(p1.scalar_mul(3), p3)

0 comments on commit 21937db

Please sign in to comment.