-
Notifications
You must be signed in to change notification settings - Fork 65
/
Copy path_pysrp.py
136 lines (105 loc) · 4.24 KB
/
_pysrp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# N A large safe prime (N = 2q+1, where q is prime)
# All arithmetic is done modulo N.
# g A generator modulo N
# k Multiplier parameter (k = H(N, g) in SRP-6a, k = 3 for legacy SRP-6)
# s User's salt
# I Username
# p Cleartext Password
# H() One-way hash function
# ^ (Modular) Exponentiation
# u Random scrambling parameter
# a,b Secret ephemeral values
# A,B Public ephemeral values
# x Private key (derived from p and s)
# v Password verifier
from .pmhash import pmhash
from .util import (PM_VERSION, SRP_LEN_BYTES, SALT_LEN_BYTES, bytes_to_long, custom_hash,
get_random_of_length, hash_password, long_to_bytes)
def get_ng(n_bin, g_hex):
return bytes_to_long(n_bin), int(g_hex, 16)
def hash_k(hash_class, g, modulus, width):
h = hash_class()
h.update(g.to_bytes(width, 'little'))
h.update(modulus.to_bytes(width, 'little'))
return bytes_to_long(h.digest())
def calculate_x(hash_class, salt, password, modulus, version):
exp = hash_password(
hash_class, password, salt, long_to_bytes(modulus, SRP_LEN_BYTES), version
)
return bytes_to_long(exp)
def calculate_client_proof(hash_class, A, B, K):
h = hash_class()
h.update(long_to_bytes(A, SRP_LEN_BYTES))
h.update(long_to_bytes(B, SRP_LEN_BYTES))
h.update(K)
return h.digest()
def calculate_server_proof(hash_class, A, M, K):
h = hash_class()
h.update(long_to_bytes(A, SRP_LEN_BYTES))
h.update(M)
h.update(K)
return h.digest()
class User(object):
def __init__(self, password, n_bin, g_hex=b"2", bytes_a=None, bytes_A=None): # noqa
if bytes_a and len(bytes_a) != 32:
raise ValueError("32 bytes required for bytes_a")
if not isinstance(password, str) or len(password) == 0:
raise ValueError("Invalid password")
self.N, self.g = get_ng(n_bin, g_hex)
self.hash_class = pmhash
self.k = hash_k(
self.hash_class, self.g,
self.N, SRP_LEN_BYTES
)
self.p = password.encode()
self.a = bytes_to_long(bytes_a) if bytes_a else get_random_of_length(32)
self.A = bytes_to_long(bytes_A) if bytes_A else pow(self.g, self.a, self.N)
self.v = None
self.M = None
self.K = None
self.expected_server_proof = None
self._authenticated = False
self.bytes_s = None
self.S = None
self.B = None
self.u = None
self.x = None
def authenticated(self):
return self._authenticated
def get_ephemeral_secret(self):
return long_to_bytes(self.a, SRP_LEN_BYTES)
def get_session_key(self):
return self.K if self._authenticated else None
def get_challenge(self):
return long_to_bytes(self.A, SRP_LEN_BYTES)
# Returns M or None if SRP-6a safety check is violated
def process_challenge(
self, bytes_s, bytes_server_challenge, version=PM_VERSION
):
self.bytes_s = bytes_s
self.B = bytes_to_long(bytes_server_challenge)
# SRP-6a safety check
if (self.B % self.N) == 0:
return None
self.u = custom_hash(self.hash_class, self.A, self.B)
# SRP-6a safety check
if self.u == 0:
return None
self.x = calculate_x(self.hash_class, self.bytes_s, self.p, self.N, version)
self.v = pow(self.g, self.x, self.N)
self.S = pow(
(self.B - self.k * self.v), (self.a + self.u * self.x), self.N
)
self.K = long_to_bytes(self.S, SRP_LEN_BYTES)
self.M = calculate_client_proof(self.hash_class, self.A, self.B, self.K) # noqa
self.expected_server_proof = calculate_server_proof(
self.hash_class, self.A, self.M, self.K
)
return self.M
def verify_session(self, server_proof):
if self.expected_server_proof == server_proof:
self._authenticated = True
def compute_v(self, bytes_s=None, version=PM_VERSION):
self.bytes_s = long_to_bytes(get_random_of_length(SALT_LEN_BYTES), SALT_LEN_BYTES) if bytes_s is None else bytes_s # noqa
self.x = calculate_x(self.hash_class, self.bytes_s, self.p, self.N, version)
return self.bytes_s, long_to_bytes(pow(self.g, self.x, self.N), SRP_LEN_BYTES) # noqa