Skip to content

Commit

Permalink
Fixed python3 migration multiple errors
Browse files Browse the repository at this point in the history
  • Loading branch information
RedRaysTeam committed Sep 3, 2024
1 parent d3ea82b commit b02d6e8
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 107 deletions.
35 changes: 20 additions & 15 deletions pysap/SAPCredv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def decrypt_MSCryptProtect(plain, cred):
:rtype: string
"""
entropy = cred.pse_path
return dpapi_decrypt_blob(unhexlify(plain.blob.val), entropy)
return dpapi_decrypt_blob(unhexlify(plain.blob.val).decode(), entropy)

PROVIDER_MSCryptProtect = "MSCryptProtect"
"""Provider for Windows hosts using DPAPI"""
Expand Down Expand Up @@ -170,14 +170,14 @@ def lps_type_str(self):
@property
def cipher_format_version(self):
cipher = self.cipher.val_readable
if len(cipher) >= 36 and ord(cipher[0]) in [0, 1]:
return ord(cipher[0])
if len(cipher) >= 36 and (cipher[0]) in [0, 1]:
return cipher[0]
return 0

@property
def cipher_algorithm(self):
if self.cipher_format_version == 1:
return ord(self.cipher.val_readable[1])
return self.cipher.val_readable[1]
return 0

def decrypt(self, username):
Expand Down Expand Up @@ -216,34 +216,39 @@ def decrypt_simple(self, username):
iv = "\x00" * 8

# Decrypt the cipher text with the derived key and IV
decryptor = Cipher(algorithms.TripleDES(key), modes.CBC(iv), backend=default_backend()).decryptor()
decryptor = Cipher(algorithms.TripleDES(key.encode()), modes.CBC(iv.encode()), backend=default_backend()).decryptor()
plain = decryptor.update(blob) + decryptor.finalize()

return SAPCredv2_Cred_Plain(plain)

def xor(self, string, start):
"""XOR a given string using a fixed key and a starting number."""
def xor(self, data, start):
"""XOR given data using a fixed key and a starting number."""
key = 0x15a4e35
x = start
y = ""
for c in string:
y = bytearray()
for value in data:
x *= key
x += 1
y += chr(ord(c) ^ (x & 0xff))
# Ensure the value is treated as an integer
if isinstance(value, str):
value = ord(value)
# Perform the XOR operation
xored_value = value ^ (x & 0xff)
y.append(xored_value)
return y

def derive_key(self, key, blob, header, username):
"""Derive a key using SAP's algorithm. The key is derived using SHA256 and xor from an
initial key, a header, salt and username.
"""
digest = Hash(SHA256(), backend=default_backend())
digest.update(key)
digest.update(key.encode())
digest.update(blob[0:4])
digest.update(header.salt)
digest.update(self.xor(username, ord(header.salt[0])))
digest.update("" * 0x20)
digest.update(self.xor(username, (header.salt[0])))
digest.update(b"" * 0x20)
hashed = digest.finalize()
derived_key = self.xor(hashed, ord(header.salt[1]))
derived_key = self.xor(hashed, (header.salt[1]))

# Validate and select proper algorithm
if header.algorithm == CIPHER_ALGORITHM_3DES:
Expand Down Expand Up @@ -377,7 +382,7 @@ def decrypt(self, username=None):
plain = cipher.decrypt()

# Get the pin from the raw data
plain_size = ord(plain[0])
plain_size = (plain[0])
pin = plain[plain_size + 1:]

# Create a plain credential container
Expand Down
6 changes: 3 additions & 3 deletions pysap/SAPLPS.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def decrypt(self):

# Decrypt the cipher text with the encryption key
iv = "\x00" * 16
decryptor = Cipher(algorithms.AES(encryption_key), modes.CBC(iv)).decryptor()
decryptor = Cipher(algorithms.AES(encryption_key), modes.CBC(iv.encode())).decryptor()
plain = decryptor.update(self.encrypted_data) + decryptor.finalize()

# TODO: Calculate and validate HMAC
Expand Down Expand Up @@ -141,15 +141,15 @@ def decrypt_encryption_key_fallback(self):
log_lps.debug("Obtaining encryption key with FALLBACK LPS mode")

digest = Hash(SHA1())
digest.update(cred_key_lps_fallback)
digest.update(cred_key_lps_fallback.encode())
hashed_key = digest.finalize()

hmac = HMAC(hashed_key, SHA1())
hmac.update(self.context)
default_key = hmac.finalize()[:16]

iv = "\x00" * 16
decryptor = Cipher(algorithms.AES(default_key), modes.CBC(iv)).decryptor()
decryptor = Cipher(algorithms.AES(default_key), modes.CBC(iv.encode())).decryptor()
encryption_key = decryptor.update(self.encrypted_key) + decryptor.finalize()

return encryption_key
Expand Down
2 changes: 1 addition & 1 deletion pysap/SAPPSE.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def decrypt_non_lps(self, pin):
# On version 2, we can check that the PIN was valid before decrypting the whole
# cipher text
if self.version == 2:
encrypted_pin = pbes.encrypt(pin)
encrypted_pin = pbes.encrypt(pin.encode())
if encrypted_pin != self.enc_cont.encrypted_pin.val:
raise ValueError("Invalid PIN supplied")

Expand Down
17 changes: 9 additions & 8 deletions pysap/SAPRouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,15 @@ def from_hops(cls, route_hops):
:return: route string
:rtype: C{string}
"""
result = ""
for route_hop in route_hops:
result += "/H/{}".format(route_hop.hostname)
if route_hop.port:
result += "/S/{}".format(route_hop.port)
if route_hop.password:
result += "/W/{}".format(route_hop.password)
return result
route_string = ""
for hop in route_hops:
if hop.hostname:
route_string += f"/H/{hop.hostname}"
if hop.port:
route_string += f"/S/{hop.port}"
if hop.password:
route_string += f"/W/{hop.password}"
return route_string


class SAPRouterInfoClient(PacketNoPadded):
Expand Down
18 changes: 10 additions & 8 deletions pysap/SAPSSFS.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def valid(self):
blob = str(self)

digest = Hash(SHA1())
digest.update(blob[:8])
digest.update(blob[8:8+4])
digest.update(blob[:8].encode())
digest.update(blob[8:8+4].encode())
if self.length:
digest.update(blob[0x20:0x20 + self.length])
digest.update(blob[0x20:0x20 + self.length].encode())
if len(blob) > self.length + 0x20:
digest.update(blob[0x20 + self.length:])
digest.update(blob[0x20 + self.length:].encode())
blob_hash = digest.finalize()
return blob_hash == self.hash

Expand Down Expand Up @@ -158,8 +158,8 @@ def valid(self):
"""Returns whether the HMAC-SHA1 value is valid for the given payload"""

# Calculate the HMAC-SHA1
h = HMAC(ssfs_hmac_key_unobscured, SHA1())
h.update(str(self)[24:156]) # Entire Data header without the HMAC field
h = HMAC(ssfs_hmac_key_unobscured.encode(), SHA1())
h.update((self)[24:156]) # Entire Data header without the HMAC field
h.update(self.data)

# Validate the signature
Expand Down Expand Up @@ -194,8 +194,9 @@ def has_record(self, key_name):
:return: if the data file contains the record with key_name
:rtype: bool
"""
key_name_bytes = key_name.encode('utf-8') # Convert input string to bytes
for record in self.records:
if record.key_name.rstrip(" ") == key_name:
if record.key_name.rstrip(b" ") == key_name_bytes: # Use b" " for byte string
return True
return False

Expand All @@ -208,8 +209,9 @@ def get_records(self, key_name):
:return: the record with key_name
:rtype: SAPSSFSDataRecord
"""
key_name_bytes = key_name.encode('utf-8') # Convert input string to bytes
for record in self.records:
if record.key_name.rstrip(" ") == key_name:
if record.key_name.rstrip(b" ") == key_name_bytes:
yield record

def get_record(self, key_name):
Expand Down
41 changes: 19 additions & 22 deletions pysap/utils/crypto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def derive(self, password):
v = self._algorithm.block_size

# Step 1 - Concatenate v/8 copies of ID
d = chr(self._id) * v
d = bytes([self._id]) * v

def concatenate_string(inp):
s = b''
Expand All @@ -135,7 +135,6 @@ def concatenate_string(inp):
c = int(math.ceil(float(self._length) / u))

# Step 6

def digest(inp):
h = Hash(self._algorithm())
h.update(inp)
Expand All @@ -144,17 +143,17 @@ def digest(inp):
def to_int(value):
if value == b'':
return 0
return int(value.encode("hex"), 16)
return int.from_bytes(value, byteorder='big')

def to_bytes(value):
value = "%x" % value
if len(value) & 1:
value = "0" + value
return value.decode("hex")
def to_bytes(value, length):
try:
return value.to_bytes(length, byteorder='big')
except OverflowError:
# If the integer is too large, we'll take the least significant bytes
return (value & ((1 << (8 * length)) - 1)).to_bytes(length, byteorder='big')

a = b'\x00' * (c * u)
for n in range(1, c + 1):

a2 = digest(d + i)
for _ in range(2, self._iterations + 1):
a2 = digest(a2)
Expand All @@ -172,13 +171,9 @@ def to_bytes(value):
start = n2 * v
end = (n2 + 1) * v
i_n2 = i[start:end]
i_n2 = to_bytes(to_int(i_n2) + b)

i_n2_l = len(i_n2)
if i_n2_l > v:
i_n2 = i_n2[i_n2_l - v:]
i_n2 = to_bytes(to_int(i_n2) + b, v)

i = i[0:start] + i_n2 + i[end:]
i = i[:start] + i_n2 + i[end:]

# Step 7
start = (n - 1) * u
Expand Down Expand Up @@ -230,6 +225,8 @@ def __init__(self, salt, iterations, iv, password, hash_algorithm, enc_algorithm
self._derive_key, self._iv = self.derive_key(salt, iterations, password)

def derive_key(self, salt, iterations, password):
if isinstance(password, str):
password = password.encode()
pkcs12_pbkdf1 = PKCS12_PBKDF1(self._hash_algorithm, 24, salt, iterations, 1)
key = pkcs12_pbkdf1.derive(password)

Expand All @@ -248,11 +245,11 @@ def encrypt(self, plain_text):
return cipher_text

def decrypt(self, cipher_text):
padder = padding.PKCS7(self._hash_algorithm.block_size).padder()
cipher_text = padder.update(cipher_text) + padder.finalize()

decryptor = Cipher(self._enc_algorithm(self._derive_key), self._enc_mode(self._iv)).decryptor()
plain_text = decryptor.update(cipher_text) + decryptor.finalize()
padded_plain_text = decryptor.update(cipher_text) + decryptor.finalize()

unpadder = padding.PKCS7(self._hash_algorithm.block_size).unpadder()
plain_text = unpadder.update(padded_plain_text) + unpadder.finalize()

return plain_text

Expand Down Expand Up @@ -348,8 +345,8 @@ def rsec_decrypt(blob, key):
if len(key) != 24:
raise Exception("Wrong key length")

blob = [ord(i) for i in blob]
key = [ord(i) for i in key]
blob = [i for i in blob]
key = [i for i in key]
key1 = key[0:8]
key2 = key[8:16]
key3 = key[16:24]
Expand All @@ -359,4 +356,4 @@ def rsec_decrypt(blob, key):
round_2 = cipher.crypt(RSECCipher.MODE_ENCODE, round_1, key2, len(round_1))
round_3 = cipher.crypt(RSECCipher.MODE_DECODE, round_2, key1, len(round_2))

return ''.join([chr(i) for i in round_3])
return bytes(round_3)
4 changes: 2 additions & 2 deletions tests/test_crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_scram_sha256_scramble_salt(self):
Values are taken from https://github.com/SAP/PyHDB/blob/master/tests/test_auth.py
"""

password = "secret"
password = b"secret"
salt = b"\x80\x96\x4f\xa8\x54\x28\xae\x3a\x81\xac" \
b"\xd3\xe6\x86\xa2\x79\x33"
server_key = b"\x41\x06\x51\x50\x11\x7e\x45\x5f\xec\x2f\x03\xf6" \
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_scram_pbkdf2sha256_scramble_salt(self):
Values are taken from https://github.com/SAP/go-hdb/blob/master/internal/protocol/authentication_test.go
"""

password = "Toor1234"
password = b"Toor1234"
rounds = 15000
salt = b"3\xb2\xd5\xd5\\R\xc2(Px\xc5[\xa6C\x17?"
server_key = b" [\xa5\x12\x9eM\x86E\x80\x9dE\xd1/!\xab\xa48\xac\xe5\x00\x99\x03A\x1d\xef\xd2\xba\x86Q \x1d\x89\xef\xa7'\x01\xabuU\x8am&*M+*RF"
Expand Down
31 changes: 23 additions & 8 deletions tests/test_sapcredv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
class PySAPCredV2Test(unittest.TestCase):

decrypt_username = "username"
decrypt_pin = "1234567890"
decrypt_pin = b"1234567890"
cert_name = "CN=PSEOwner"
common_name = "PSEOwner"
subject_str = "/CN=PSEOwner"
Expand All @@ -49,17 +49,32 @@ def validate_credv2_lps_off_fields(self, creds, number, lps_type, cipher_format_
self.assertEqual(len(creds), number)
cred = creds[0].cred

self.assertEqual(cred.common_name, cert_name or self.cert_name)
self.assertEqual(cred.pse_file_path, pse_path or self.pse_path)
# Check if cert_name is None before encoding
cert_name_encoded = cert_name.encode() if cert_name is not None else None

# Check if pse_path is None before encoding
pse_path_encoded = pse_path.encode() if pse_path is not None else None

# Assert common_name (previously cert_name)
self.assertEqual(cred.common_name, cert_name_encoded or self.cert_name.encode())

# Assert pse_file_path (previously pse_path)
self.assertEqual(cred.pse_file_path, pse_path_encoded or self.pse_path.encode())

# These assertions remain unchanged
self.assertEqual(cred.lps_type, lps_type)
self.assertEqual(cred.cipher_format_version, cipher_format_version)
self.assertEqual(cred.cipher_algorithm, cipher_algorithm)

self.assertEqual(cred.cert_name, cert_name or self.cert_name)
self.assertEqual(cred.unknown1, "")
self.assertEqual(cred.pse_path, pse_path or self.pse_path)
self.assertEqual(cred.unknown2, "")
# Assert cert_name (previously common_name)
self.assertEqual(cred.cert_name, cert_name_encoded or self.cert_name.encode())

self.assertEqual(cred.unknown1, b"")

# Assert pse_path (previously pse_file_path)
self.assertEqual(cred.pse_path, pse_path_encoded or self.pse_path.encode())

self.assertEqual(cred.unknown2, b"")
def validate_credv2_plain(self, cred, decrypt_username=None, decrypt_pin=None):
plain = cred.decrypt(decrypt_username or self.decrypt_username)
self.assertEqual(plain.pin.val, decrypt_pin or self.decrypt_pin)
Expand Down Expand Up @@ -200,7 +215,7 @@ def test_credv2_lps_on_v2_int_aes256_composed_subject(self):
subject = [
X509_RDN(rdn=[
X509_AttributeTypeAndValue(type=ASN1_OID("2.5.4.6"),
value=ASN1_PRINTABLE_STRING("AR"))]),
value=ASN1_PRINTABLE_STRING(b"AR"))]),
X509_RDN(rdn=[
X509_AttributeTypeAndValue(type=ASN1_OID("2.5.4.3"),
value=ASN1_PRINTABLE_STRING(self.common_name))]),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_sapdiag.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_sapdiag_header_dissection_plain(self):

diag_header_plain = SAPDiag(compress=0)
diag_header_plain.message.append(diag_item)
new_diag_header_plain = SAPDiag(str(diag_header_plain))
new_diag_header_plain = SAPDiag(diag_header_plain)

self.assertEqual(str(diag_header_plain),
str(new_diag_header_plain))
Expand Down Expand Up @@ -148,9 +148,9 @@ class SAPDiagItemTest(Packet):
fields_desc = [StrField("strfield", None)]
bind_diagitem(SAPDiagItemTest, "APPL", 0x99, 0xff)

item_string = "strfield"
item_string = b"strfield"
item_value = SAPDiagItemTest(strfield=item_string)
item = SAPDiagItem("\x10\x99\xff" + pack("!H", len(item_string)) + item_string)
item = SAPDiagItem(b"\x10\x99\xff" + pack("!H", len(item_string)) + item_string)

self.assertEqual(item.item_value, item_value)
self.assertEqual(item.item_length, len(item_string))
Expand Down
Loading

0 comments on commit b02d6e8

Please sign in to comment.