diff --git a/src/main/java/com/lambdaworks/crypto/SCryptUtil.java b/src/main/java/com/lambdaworks/crypto/SCryptUtil.java
index ca29a00..03c30d6 100644
--- a/src/main/java/com/lambdaworks/crypto/SCryptUtil.java
+++ b/src/main/java/com/lambdaworks/crypto/SCryptUtil.java
@@ -3,6 +3,8 @@
package com.lambdaworks.crypto;
import java.io.UnsupportedEncodingException;
+import java.lang.management.ManagementFactory;
+import java.lang.management.ThreadMXBean;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
@@ -28,6 +30,12 @@
* @author Will Glozer
*/
public class SCryptUtil {
+ // for timedIterations()
+ private static final byte[] BENCH_PASSWD = "secret".getBytes();
+ private static final byte[] BENCH_SALT = "1234".getBytes();
+ private static final int BENCH_DK_LEN = 32;
+ private static final int BENCH_INITIAL_N = 64;
+
/**
* Hash the supplied plaintext password and generate output in the format described
* in {@link SCryptUtil}.
@@ -109,4 +117,59 @@ private static int log2(int n) {
if (n >= 4 ) { n >>>= 2; log += 2; }
return log + (n >>> 1);
}
+
+ /**
+ * Determines a CPU cost value (i.e. a value for the N parameter) that will cause password
+ * verification to take (roughly) a given time on the current CPU for the specified
+ * r
and p
values.
+ * N is rounded to the nearest power of two because only powers of two are valid
+ * choices for N. The actual time spent will be between about .7*milliseconds
+ * and 1.4*milliseconds
.
+ *
+ * @param milliseconds the time scrypt should spend verifying a password
+ * @param r memory cost parameter
+ * @param p parallelization parameter
+ *
+ * @return a value for N such that scrypt(N, r, p)
runs for roughly milliseconds
+ *
+ * @throws GeneralSecurityException when HMAC_SHA256 is not available.
+ */
+ public static int timedIterations(int milliseconds, int r, int p) throws GeneralSecurityException {
+ ThreadMXBean threadBean = ManagementFactory.getThreadMXBean();
+ boolean cpuTimeSupported = threadBean.isCurrentThreadCpuTimeSupported();
+ boolean origEnabledFlag = false;
+ if (cpuTimeSupported) {
+ origEnabledFlag = threadBean.isThreadCpuTimeEnabled();
+ if (!origEnabledFlag)
+ threadBean.setThreadCpuTimeEnabled(true);
+ }
+
+ int N = BENCH_INITIAL_N;
+ long lastDelta = 0;
+ while (true) {
+ // prefer CPU time over real world time so the result is load independent
+ long startTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime();
+ SCrypt.scrypt(BENCH_PASSWD, BENCH_SALT, N, r, p, BENCH_DK_LEN);
+ long endTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime();
+ long delta = (endTime-startTime) / 1000000;
+
+ // start over if a speed increase is detected due to the code being JITted
+ if (delta < lastDelta) {
+ N = BENCH_INITIAL_N;
+ lastDelta = 0;
+ continue;
+ }
+
+ if (delta > milliseconds) {
+ if (cpuTimeSupported)
+ threadBean.setThreadCpuTimeEnabled(origEnabledFlag);
+ // round to the nearest power of two
+ if (delta-delta/4 > milliseconds)
+ N /= 2;
+ return N;
+ }
+ N *= 2;
+ lastDelta = delta;
+ }
+ }
}
diff --git a/src/test/java/com/lambdaworks/crypto/test/SCryptUtilTest.java b/src/test/java/com/lambdaworks/crypto/test/SCryptUtilTest.java
index f673657..c81723d 100644
--- a/src/test/java/com/lambdaworks/crypto/test/SCryptUtilTest.java
+++ b/src/test/java/com/lambdaworks/crypto/test/SCryptUtilTest.java
@@ -3,7 +3,12 @@
package com.lambdaworks.crypto.test;
import com.lambdaworks.codec.Base64;
+import com.lambdaworks.crypto.SCrypt;
import com.lambdaworks.crypto.SCryptUtil;
+import java.lang.management.ManagementFactory;
+import java.lang.management.ThreadMXBean;
+import java.security.GeneralSecurityException;
+import java.util.Random;
import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.*;
@@ -56,4 +61,27 @@ public void format_0_rp_max() throws Exception {
assertEquals(r, params >> 8 & 0xff);
assertEquals(p, params >> 0 & 0xff);
}
+
+ @Test
+ public void testTimedIterations() throws GeneralSecurityException {
+ byte[] salt = "1234".getBytes();
+ int dkLen = 32;
+
+ ThreadMXBean threadBean = ManagementFactory.getThreadMXBean();
+ boolean cpuTimeSupported = threadBean.isCurrentThreadCpuTimeSupported();
+ Random random = new Random();
+ for (int i=0; i<5; i++) {
+ int targetDuration = 100 + random.nextInt(900);
+ int numIterations = SCryptUtil.timedIterations(targetDuration, 8, 1);
+ long startTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime();
+ SCrypt.scrypt(passwd.getBytes(), salt, numIterations, 8, 1, dkLen);
+ long endTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime();
+ long actualDuration = (endTime-startTime) / 1000000;
+
+ // check that actual duration is within targetDuration - 50% and targetDuration + 60%
+ String failMessage = "Target duration=" + targetDuration + ", actual=" + actualDuration;
+ assertTrue(failMessage, actualDuration > targetDuration*0.5);
+ assertTrue(failMessage, actualDuration < targetDuration*1.6);
+ }
+ }
}