diff --git a/src/main/java/com/lambdaworks/crypto/SCryptUtil.java b/src/main/java/com/lambdaworks/crypto/SCryptUtil.java index ca29a00..03c30d6 100644 --- a/src/main/java/com/lambdaworks/crypto/SCryptUtil.java +++ b/src/main/java/com/lambdaworks/crypto/SCryptUtil.java @@ -3,6 +3,8 @@ package com.lambdaworks.crypto; import java.io.UnsupportedEncodingException; +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; import java.security.GeneralSecurityException; import java.security.SecureRandom; @@ -28,6 +30,12 @@ * @author Will Glozer */ public class SCryptUtil { + // for timedIterations() + private static final byte[] BENCH_PASSWD = "secret".getBytes(); + private static final byte[] BENCH_SALT = "1234".getBytes(); + private static final int BENCH_DK_LEN = 32; + private static final int BENCH_INITIAL_N = 64; + /** * Hash the supplied plaintext password and generate output in the format described * in {@link SCryptUtil}. @@ -109,4 +117,59 @@ private static int log2(int n) { if (n >= 4 ) { n >>>= 2; log += 2; } return log + (n >>> 1); } + + /** + * Determines a CPU cost value (i.e. a value for the N parameter) that will cause password + * verification to take (roughly) a given time on the current CPU for the specified + * r and p values.
+ * N is rounded to the nearest power of two because only powers of two are valid + * choices for N. The actual time spent will be between about .7*milliseconds + * and 1.4*milliseconds. + * + * @param milliseconds the time scrypt should spend verifying a password + * @param r memory cost parameter + * @param p parallelization parameter + * + * @return a value for N such that scrypt(N, r, p) runs for roughly milliseconds + * + * @throws GeneralSecurityException when HMAC_SHA256 is not available. + */ + public static int timedIterations(int milliseconds, int r, int p) throws GeneralSecurityException { + ThreadMXBean threadBean = ManagementFactory.getThreadMXBean(); + boolean cpuTimeSupported = threadBean.isCurrentThreadCpuTimeSupported(); + boolean origEnabledFlag = false; + if (cpuTimeSupported) { + origEnabledFlag = threadBean.isThreadCpuTimeEnabled(); + if (!origEnabledFlag) + threadBean.setThreadCpuTimeEnabled(true); + } + + int N = BENCH_INITIAL_N; + long lastDelta = 0; + while (true) { + // prefer CPU time over real world time so the result is load independent + long startTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime(); + SCrypt.scrypt(BENCH_PASSWD, BENCH_SALT, N, r, p, BENCH_DK_LEN); + long endTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime(); + long delta = (endTime-startTime) / 1000000; + + // start over if a speed increase is detected due to the code being JITted + if (delta < lastDelta) { + N = BENCH_INITIAL_N; + lastDelta = 0; + continue; + } + + if (delta > milliseconds) { + if (cpuTimeSupported) + threadBean.setThreadCpuTimeEnabled(origEnabledFlag); + // round to the nearest power of two + if (delta-delta/4 > milliseconds) + N /= 2; + return N; + } + N *= 2; + lastDelta = delta; + } + } } diff --git a/src/test/java/com/lambdaworks/crypto/test/SCryptUtilTest.java b/src/test/java/com/lambdaworks/crypto/test/SCryptUtilTest.java index f673657..c81723d 100644 --- a/src/test/java/com/lambdaworks/crypto/test/SCryptUtilTest.java +++ b/src/test/java/com/lambdaworks/crypto/test/SCryptUtilTest.java @@ -3,7 +3,12 @@ package com.lambdaworks.crypto.test; import com.lambdaworks.codec.Base64; +import com.lambdaworks.crypto.SCrypt; import com.lambdaworks.crypto.SCryptUtil; +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; +import java.security.GeneralSecurityException; +import java.util.Random; import org.junit.Assert; import org.junit.Test; import static org.junit.Assert.*; @@ -56,4 +61,27 @@ public void format_0_rp_max() throws Exception { assertEquals(r, params >> 8 & 0xff); assertEquals(p, params >> 0 & 0xff); } + + @Test + public void testTimedIterations() throws GeneralSecurityException { + byte[] salt = "1234".getBytes(); + int dkLen = 32; + + ThreadMXBean threadBean = ManagementFactory.getThreadMXBean(); + boolean cpuTimeSupported = threadBean.isCurrentThreadCpuTimeSupported(); + Random random = new Random(); + for (int i=0; i<5; i++) { + int targetDuration = 100 + random.nextInt(900); + int numIterations = SCryptUtil.timedIterations(targetDuration, 8, 1); + long startTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime(); + SCrypt.scrypt(passwd.getBytes(), salt, numIterations, 8, 1, dkLen); + long endTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime(); + long actualDuration = (endTime-startTime) / 1000000; + + // check that actual duration is within targetDuration - 50% and targetDuration + 60% + String failMessage = "Target duration=" + targetDuration + ", actual=" + actualDuration; + assertTrue(failMessage, actualDuration > targetDuration*0.5); + assertTrue(failMessage, actualDuration < targetDuration*1.6); + } + } }