Skip to content

Commit

Permalink
adds chacha and salsa variants
Browse files Browse the repository at this point in the history
  • Loading branch information
huettenhain committed Oct 22, 2024
1 parent 0a70030 commit 66b1559
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 8 deletions.
37 changes: 35 additions & 2 deletions refinery/units/crypto/cipher/chacha.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from Cryptodome.Cipher import ChaCha20
from Cryptodome.Cipher import ChaCha20, ChaCha20_Poly1305
from typing import List, Iterable

from refinery.units.crypto.cipher.salsa import LatinCipher
import struct

from refinery.units.crypto.cipher.salsa import LatinCipher, LatinX
from refinery.units.crypto.cipher import LatinCipherUnit, LatinCipherStandardUnit
from refinery.lib.crypto import rotl32, PyCryptoFactoryWrapper

Expand Down Expand Up @@ -44,6 +46,18 @@ class chacha20(LatinCipherStandardUnit, cipher=PyCryptoFactoryWrapper(ChaCha20))
pass


class chacha20poly1305(LatinCipherStandardUnit, cipher=PyCryptoFactoryWrapper(ChaCha20_Poly1305)):
"""
ChaCha20-Poly1305 and XChaCha20-Poly1305 encryption and decryption. For the ChaCha20
variant, the nonce must be 8 or 12 bytes long; for XChaCha20, provide a 24 bytes nonce
instead.
"""
def _get_cipher(self, reset_cache=False):
cipher = super()._get_cipher(reset_cache)
cipher.block_size = 1
return cipher


class chacha(LatinCipherUnit):
"""
ChaCha encryption and decryption. The nonce must be 8 bytes long as currently, only the
Expand All @@ -63,3 +77,22 @@ def keystream(self) -> Iterable[int]:
self.args.offset,
)
yield from it


class xchacha(LatinCipherUnit):
"""
XChaCha encryption and decryption. The nonce must be 24 bytes long.
"""
def keystream(self) -> Iterable[int]:
kdp, kdn, nonce = struct.unpack('<Q8s8s', self.args.nonce)
yield from LatinX(
ChaChaCipher,
(0, 1, 2, 3, 12, 13, 14, 15),
self.args.key,
kdn,
kdp,
nonce,
self.args.magic,
self.args.rounds,
self.args.offset,
)
56 changes: 50 additions & 6 deletions refinery/units/crypto/cipher/salsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from Cryptodome.Cipher import Salsa20
from abc import ABC, abstractmethod
from typing import List, ByteString, Union, Sequence, Optional, Iterable, Tuple
from typing import List, Union, Sequence, Optional, Iterable, Tuple, Type, TypeVar

from refinery.units.crypto.cipher import LatinCipherUnit, LatinCipherStandardUnit
from refinery.lib.crypto import rotl32, PyCryptoFactoryWrapper
from refinery.lib.types import ByteStr


class LatinCipher(ABC):
Expand All @@ -19,7 +20,7 @@ class LatinCipher(ABC):
_round_access_pattern: Tuple[Tuple[int, int, int, int], ...]

@classmethod
def FromState(cls, state: Union[Sequence[int], ByteString]):
def FromState(cls, state: Union[Sequence[int], ByteStr]):
try:
state = struct.unpack('<16L', state)
except TypeError:
Expand All @@ -37,9 +38,9 @@ def FromState(cls, state: Union[Sequence[int], ByteString]):
'<2L', *state[cls._idx_count]), 'little')
return cls(key, nonce, magic, counter=count)

def __init__(self, key: ByteString, nonce: ByteString, magic: Optional[ByteString] = None, rounds: int = 20, counter: int = 0):
def __init__(self, key: ByteStr, nonce: ByteStr, magic: Optional[ByteStr] = None, rounds: int = 20, counter: int = 0):
if len(key) == 16:
key += key
key = 2 * key
elif len(key) != 32:
raise ValueError('The key must be of length 16 or 32.')
if rounds % 2:
Expand Down Expand Up @@ -85,12 +86,15 @@ def count(self):
def quarter(self, x: List[int], a: int, b: int, c: int, d: int):
raise NotImplementedError

def permute(self, x: List[int]):
for a, b, c, d in self.rounds * self._round_access_pattern:
self.quarter(x, a, b, c, d)

def __iter__(self):
x = [0] * len(self.state)
while True:
x[:] = self.state
for a, b, c, d in self.rounds * self._round_access_pattern:
self.quarter(x, a, b, c, d)
self.permute(x)
yield from struct.pack('<16L', *(
(a + b) & 0xFFFFFFFF for a, b in zip(x, self.state)))
self.count()
Expand Down Expand Up @@ -121,6 +125,27 @@ def quarter(x: List[int], a: int, b: int, c: int, d: int) -> None:
x[a] ^= rotl32(x[d] + x[c] & 0xFFFFFFFF, 0x12)


_X = TypeVar('_X', bound=LatinCipher)


def LatinX(
cipher: Type[_X],
blocks: Iterable[int],
key: ByteStr,
kdn: ByteStr,
kdp: ByteStr,
nonce: ByteStr,
magic: ByteStr,
rounds: int,
offset: int,
) -> _X:
from refinery.lib import chunks
kd = cipher(key, kdn, magic, rounds, kdp)
kd.permute(kd.state)
key = chunks.pack((kd.state[i] for i in blocks), 4)
return cipher(key, nonce, magic, rounds, offset)


class salsa(LatinCipherUnit):
"""
Salsa encryption and decryption. The nonce must be 8 bytes long. When 64 bytes are provided
Expand All @@ -142,6 +167,25 @@ def keystream(self) -> Iterable[int]:
yield from it


class xsalsa(LatinCipherUnit):
"""
XSalsa encryption and decryption. The nonce must be 24 bytes long.
"""
def keystream(self) -> Iterable[int]:
kdn, kdp, nonce = struct.unpack('<8sQ8s', self.args.nonce)
yield from LatinX(
SalsaCipher,
(0, 5, 10, 15, 6, 7, 8, 9),
self.args.key,
kdn,
kdp,
nonce,
self.args.magic,
self.args.rounds,
self.args.offset,
)


class salsa20(LatinCipherStandardUnit, cipher=PyCryptoFactoryWrapper(Salsa20)):
"""
Salsa20 encryption and decryption. This unit is functionally equivalent to `refinery.salsa`
Expand Down
37 changes: 37 additions & 0 deletions test/units/crypto/cipher/test_chacha.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from ... import TestUnitBase


class TestChaCha(TestUnitBase):

def test_xchacha(self):
key = bytes.fromhex(
'808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f')
nonce = bytes.fromhex(
'404142434445464748494a4b4c4d4e4f5051525354555658')
data = bytes.fromhex(
'5468652064686f6c65202870726f6e6f756e6365642022646f6c652229206973'
'20616c736f206b6e6f776e2061732074686520417369617469632077696c6420'
'646f672c2072656420646f672c20616e642077686973746c696e6720646f672e'
'2049742069732061626f7574207468652073697a65206f662061204765726d61'
'6e20736865706865726420627574206c6f6f6b73206d6f7265206c696b652061'
'206c6f6e672d6c656767656420666f782e205468697320686967686c7920656c'
'757369766520616e6420736b696c6c6564206a756d70657220697320636c6173'
'736966696564207769746820776f6c7665732c20636f796f7465732c206a6163'
'6b616c732c20616e6420666f78657320696e20746865207461786f6e6f6d6963'
'2066616d696c792043616e696461652e')
goal = bytes.fromhex(
'7d0a2e6b7f7c65a236542630294e063b7ab9b555a5d5149aa21e4ae1e4fbce87'
'ecc8e08a8b5e350abe622b2ffa617b202cfad72032a3037e76ffdcdc4376ee05'
'3a190d7e46ca1de04144850381b9cb29f051915386b8a710b8ac4d027b8b050f'
'7cba5854e028d564e453b8a968824173fc16488b8970cac828f11ae53cabd201'
'12f87107df24ee6183d2274fe4c8b1485534ef2c5fbc1ec24bfc3663efaa08bc'
'047d29d25043532db8391a8a3d776bf4372a6955827ccb0cdd4af403a7ce4c63'
'd595c75a43e045f0cce1f29c8b93bd65afc5974922f214a40b7c402cdb91ae73'
'c0b63615cdad0480680f16515a7ace9d39236464328a37743ffc28f4ddb324f4'
'd0f5bbdc270c65b1749a6efff1fbaa09536175ccd29fb9e6057b307320d31683'
'8a9c71f70b5b5907a66f7ea49aadc409'
)
test = data | self.ldu('xchacha', offset=1, key=key, nonce=nonce) | bytes
self.assertEqual(test, goal)
28 changes: 28 additions & 0 deletions test/units/crypto/cipher/test_salsa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from ... import TestUnitBase


class TestSalsa(TestUnitBase):

def test_xsalsa(self):
key = B'this is 32-byte key for xsalsa20'
nonce = B'24-byte nonce for xsalsa'
data = [B'Hello world!', bytearray(64)]
goal = [
bytes((
0x00, 0x2d, 0x45, 0x13, 0x84, 0x3f, 0xc2, 0x40,
0xc4, 0x01, 0xe5, 0x41)),
bytes((
0x48, 0x48, 0x29, 0x7f, 0xeb, 0x1f, 0xb5, 0x2f,
0xb6, 0x6d, 0x81, 0x60, 0x9b, 0xd5, 0x47, 0xfa,
0xbc, 0xbe, 0x70, 0x26, 0xed, 0xc8, 0xb5, 0xe5,
0xe4, 0x49, 0xd0, 0x88, 0xbf, 0xa6, 0x9c, 0x08,
0x8f, 0x5d, 0x8d, 0xa1, 0xd7, 0x91, 0x26, 0x7c,
0x2c, 0x19, 0x5a, 0x7f, 0x8c, 0xae, 0x9c, 0x4b,
0x40, 0x50, 0xd0, 0x8c, 0xe6, 0xd3, 0xa1, 0x51,
0xec, 0x26, 0x5f, 0x3a, 0x58, 0xe4, 0x76, 0x48))
]
for d, g in zip(data, goal):
test = d | self.ldu('xsalsa', key=key, nonce=nonce) | bytes
self.assertEqual(test, g)

0 comments on commit 66b1559

Please sign in to comment.