Skip to content
This repository has been archived by the owner on Sep 20, 2023. It is now read-only.

Commit

Permalink
Merge pull request #312 from ocheron/eddsa-gen
Browse files Browse the repository at this point in the history
Generic EdDSA implementation
  • Loading branch information
vincenthz authored Apr 14, 2021
2 parents f449a54 + 981b97a commit cf89276
Show file tree
Hide file tree
Showing 9 changed files with 651 additions and 51 deletions.
24 changes: 12 additions & 12 deletions Crypto/ECC/Edwards25519.hs
Original file line number Diff line number Diff line change
Expand Up @@ -283,45 +283,45 @@ pointsMulVarTime (Scalar s1) (Scalar s2) (Point p) =
withByteArray p $ \pp ->
ed25519_base_double_scalarmul_vartime out ps1 pp ps2

foreign import ccall "cryptonite_ed25519_scalar_eq"
foreign import ccall unsafe "cryptonite_ed25519_scalar_eq"
ed25519_scalar_eq :: Ptr Scalar
-> Ptr Scalar
-> IO CInt

foreign import ccall "cryptonite_ed25519_scalar_encode"
foreign import ccall unsafe "cryptonite_ed25519_scalar_encode"
ed25519_scalar_encode :: Ptr Word8
-> Ptr Scalar
-> IO ()

foreign import ccall "cryptonite_ed25519_scalar_decode_long"
foreign import ccall unsafe "cryptonite_ed25519_scalar_decode_long"
ed25519_scalar_decode_long :: Ptr Scalar
-> Ptr Word8
-> CSize
-> IO ()

foreign import ccall "cryptonite_ed25519_scalar_add"
foreign import ccall unsafe "cryptonite_ed25519_scalar_add"
ed25519_scalar_add :: Ptr Scalar -- sum
-> Ptr Scalar -- a
-> Ptr Scalar -- b
-> IO ()

foreign import ccall "cryptonite_ed25519_scalar_mul"
foreign import ccall unsafe "cryptonite_ed25519_scalar_mul"
ed25519_scalar_mul :: Ptr Scalar -- out
-> Ptr Scalar -- a
-> Ptr Scalar -- b
-> IO ()

foreign import ccall "cryptonite_ed25519_point_encode"
foreign import ccall unsafe "cryptonite_ed25519_point_encode"
ed25519_point_encode :: Ptr Word8
-> Ptr Point
-> IO ()

foreign import ccall "cryptonite_ed25519_point_decode_vartime"
foreign import ccall unsafe "cryptonite_ed25519_point_decode_vartime"
ed25519_point_decode_vartime :: Ptr Point
-> Ptr Word8
-> IO CInt

foreign import ccall "cryptonite_ed25519_point_eq"
foreign import ccall unsafe "cryptonite_ed25519_point_eq"
ed25519_point_eq :: Ptr Point
-> Ptr Point
-> IO CInt
Expand All @@ -330,23 +330,23 @@ foreign import ccall "cryptonite_ed25519_point_has_prime_order"
ed25519_point_has_prime_order :: Ptr Point
-> IO CInt

foreign import ccall "cryptonite_ed25519_point_negate"
foreign import ccall unsafe "cryptonite_ed25519_point_negate"
ed25519_point_negate :: Ptr Point -- minus_a
-> Ptr Point -- a
-> IO ()

foreign import ccall "cryptonite_ed25519_point_add"
foreign import ccall unsafe "cryptonite_ed25519_point_add"
ed25519_point_add :: Ptr Point -- sum
-> Ptr Point -- a
-> Ptr Point -- b
-> IO ()

foreign import ccall "cryptonite_ed25519_point_double"
foreign import ccall unsafe "cryptonite_ed25519_point_double"
ed25519_point_double :: Ptr Point -- two_a
-> Ptr Point -- a
-> IO ()

foreign import ccall "cryptonite_ed25519_point_mul_by_cofactor"
foreign import ccall unsafe "cryptonite_ed25519_point_mul_by_cofactor"
ed25519_point_mul_by_cofactor :: Ptr Point -- eight_a
-> Ptr Point -- a
-> IO ()
Expand Down
53 changes: 53 additions & 0 deletions Crypto/Internal/Builder.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
-- |
-- Module : Crypto.Internal.Builder
-- License : BSD-style
-- Maintainer : Olivier Chéron <[email protected]>
-- Stability : stable
-- Portability : Good
--
-- Delaying and merging ByteArray allocations. This is similar to module
-- "Data.ByteArray.Pack" except the total length is computed automatically based
-- on what is appended.
--
{-# LANGUAGE BangPatterns #-}
module Crypto.Internal.Builder
( Builder
, buildAndFreeze
, builderLength
, byte
, bytes
, zero
) where

import Data.ByteArray (ByteArray, ByteArrayAccess)
import qualified Data.ByteArray as B
import Data.Memory.PtrMethods (memSet)

import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (poke)

import Crypto.Internal.Imports hiding (empty)

data Builder = Builder !Int (Ptr Word8 -> IO ()) -- size and initializer

instance Semigroup Builder where
(Builder s1 f1) <> (Builder s2 f2) = Builder (s1 + s2) f
where f p = f1 p >> f2 (p `plusPtr` s1)

builderLength :: Builder -> Int
builderLength (Builder s _) = s

buildAndFreeze :: ByteArray ba => Builder -> ba
buildAndFreeze (Builder s f) = B.allocAndFreeze s f

byte :: Word8 -> Builder
byte !b = Builder 1 (`poke` b)

bytes :: ByteArrayAccess ba => ba -> Builder
bytes bs = Builder (B.length bs) (B.copyByteArrayToPtr bs)

zero :: Int -> Builder
zero s = if s > 0 then Builder s (\p -> memSet p 0 s) else empty

empty :: Builder
empty = Builder 0 (const $ return ())
4 changes: 4 additions & 0 deletions Crypto/Internal/Imports.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
-- Stability : experimental
-- Portability : unknown
--
{-# LANGUAGE CPP #-}
module Crypto.Internal.Imports
( module X
) where

import Data.Word as X
#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup as X (Semigroup(..))
#endif
import Control.Applicative as X
import Control.Monad as X (forM, forM_, void)
import Control.Arrow as X (first, second)
Expand Down
53 changes: 14 additions & 39 deletions Crypto/MAC/KMAC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ import qualified Crypto.Hash as H
import Crypto.Hash.SHAKE (HashSHAKE(..))
import Crypto.Hash.Types (HashAlgorithm(..), Digest(..))
import qualified Crypto.Hash.Types as H
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (poke)
import Crypto.Internal.Builder
import Crypto.Internal.Imports
import Foreign.Ptr (Ptr)
import Data.Bits (shiftR)
import Data.ByteArray (ByteArray, ByteArrayAccess)
import Data.ByteArray (ByteArrayAccess)
import qualified Data.ByteArray as B
import Data.Word (Word8)
import Data.Memory.PtrMethods (memSet)


-- cSHAKE
Expand All @@ -47,8 +46,8 @@ cshakeInit n s p = H.Context $ B.allocAndFreeze c $ \(ptr :: Ptr (H.Context a))
where
c = hashInternalContextSize (undefined :: a)
w = hashBlockSize (undefined :: a)
x = encodeString n <+> encodeString s
b = builderAllocAndFreeze (bytepad x w) :: B.Bytes
x = encodeString n <> encodeString s
b = buildAndFreeze (bytepad x w) :: B.Bytes

cshakeUpdate :: (HashSHAKE a, ByteArrayAccess ba)
=> H.Context a -> ba -> H.Context a
Expand Down Expand Up @@ -77,7 +76,7 @@ cshakeFinalize !c s =
-- The Eq instance is constant time. No Show instance is provided, to avoid
-- printing by mistake.
newtype KMAC a = KMAC { kmacGetDigest :: Digest a }
deriving ByteArrayAccess
deriving (ByteArrayAccess,NFData)

instance Eq (KMAC a) where
(KMAC b1) == (KMAC b2) = B.constEq b1 b2
Expand All @@ -99,7 +98,7 @@ initialize str key = Context $ cshakeInit n str p
where
n = B.pack [75,77,65,67] :: B.Bytes -- "KMAC"
w = hashBlockSize (undefined :: a)
p = builderAllocAndFreeze (bytepad (encodeString key) w) :: B.ScrubbedBytes
p = buildAndFreeze (bytepad (encodeString key) w) :: B.ScrubbedBytes

-- | Incrementally update a KMAC context.
update :: (HashSHAKE a, ByteArrayAccess ba) => Context a -> ba -> Context a
Expand All @@ -114,56 +113,32 @@ finalize :: forall a . HashSHAKE a => Context a -> KMAC a
finalize (Context ctx) = KMAC $ cshakeFinalize ctx suffix
where
l = cshakeOutputLength (undefined :: a)
suffix = builderAllocAndFreeze (rightEncode l) :: B.Bytes
suffix = buildAndFreeze (rightEncode l) :: B.Bytes


-- Utilities

bytepad :: Builder -> Int -> Builder
bytepad x w = prefix <+> x <+> zero padLen
bytepad x w = prefix <> x <> zero padLen
where
prefix = leftEncode w
padLen = (w - builderLength prefix - builderLength x) `mod` w

encodeString :: ByteArrayAccess bin => bin -> Builder
encodeString s = leftEncode (8 * B.length s) <+> bytes s
encodeString s = leftEncode (8 * B.length s) <> bytes s

leftEncode :: Int -> Builder
leftEncode x = byte len <+> digits
leftEncode x = byte len <> digits
where
digits = i2osp x
len = fromIntegral (builderLength digits)

rightEncode :: Int -> Builder
rightEncode x = digits <+> byte len
rightEncode x = digits <> byte len
where
digits = i2osp x
len = fromIntegral (builderLength digits)

i2osp :: Int -> Builder
i2osp i | i >= 256 = i2osp (shiftR i 8) <+> byte (fromIntegral i)
i2osp i | i >= 256 = i2osp (shiftR i 8) <> byte (fromIntegral i)
| otherwise = byte (fromIntegral i)


-- Delaying and merging ByteArray allocations

data Builder = Builder !Int (Ptr Word8 -> IO ()) -- size and initializer

(<+>) :: Builder -> Builder -> Builder
(Builder s1 f1) <+> (Builder s2 f2) = Builder (s1 + s2) f
where f p = f1 p >> f2 (p `plusPtr` s1)

builderLength :: Builder -> Int
builderLength (Builder s _) = s

builderAllocAndFreeze :: ByteArray ba => Builder -> ba
builderAllocAndFreeze (Builder s f) = B.allocAndFreeze s f

byte :: Word8 -> Builder
byte !b = Builder 1 (`poke` b)

bytes :: ByteArrayAccess ba => ba -> Builder
bytes bs = Builder (B.length bs) (B.copyByteArrayToPtr bs)

zero :: Int -> Builder
zero s = Builder s (\p -> memSet p 0 s)
Loading

0 comments on commit cf89276

Please sign in to comment.