Skip to content
This repository was archived by the owner on May 25, 2021. It is now read-only.

Method for calculating a value for N depending on CPU speed #12

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/main/java/com/lambdaworks/crypto/SCryptUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package com.lambdaworks.crypto;

import java.io.UnsupportedEncodingException;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;

Expand All @@ -28,6 +30,12 @@
* @author Will Glozer
*/
public class SCryptUtil {
// for timedIterations()
private static final byte[] BENCH_PASSWD = "secret".getBytes();
private static final byte[] BENCH_SALT = "1234".getBytes();
private static final int BENCH_DK_LEN = 32;
private static final int BENCH_INITIAL_N = 64;

/**
* Hash the supplied plaintext password and generate output in the format described
* in {@link SCryptUtil}.
Expand Down Expand Up @@ -109,4 +117,59 @@ private static int log2(int n) {
if (n >= 4 ) { n >>>= 2; log += 2; }
return log + (n >>> 1);
}

/**
* Determines a CPU cost value (i.e. a value for the N parameter) that will cause password
* verification to take (roughly) a given time on the current CPU for the specified
* <code>r</code> and <code>p</code> values.<br/>
* N is rounded to the nearest power of two because only powers of two are valid
* choices for N. The actual time spent will be between about .7*<code>milliseconds</code>
* and 1.4*<code>milliseconds</code>.
*
* @param milliseconds the time scrypt should spend verifying a password
* @param r memory cost parameter
* @param p parallelization parameter
*
* @return a value for N such that <code>scrypt(N, r, p)</code> runs for roughly <code>milliseconds</code>
*
* @throws GeneralSecurityException when HMAC_SHA256 is not available.
*/
public static int timedIterations(int milliseconds, int r, int p) throws GeneralSecurityException {
ThreadMXBean threadBean = ManagementFactory.getThreadMXBean();
boolean cpuTimeSupported = threadBean.isCurrentThreadCpuTimeSupported();
boolean origEnabledFlag = false;
if (cpuTimeSupported) {
origEnabledFlag = threadBean.isThreadCpuTimeEnabled();
if (!origEnabledFlag)
threadBean.setThreadCpuTimeEnabled(true);
}

int N = BENCH_INITIAL_N;
long lastDelta = 0;
while (true) {
// prefer CPU time over real world time so the result is load independent
long startTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime();
SCrypt.scrypt(BENCH_PASSWD, BENCH_SALT, N, r, p, BENCH_DK_LEN);
long endTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime();
long delta = (endTime-startTime) / 1000000;

// start over if a speed increase is detected due to the code being JITted
if (delta < lastDelta) {
N = BENCH_INITIAL_N;
lastDelta = 0;
continue;
}

if (delta > milliseconds) {
if (cpuTimeSupported)
threadBean.setThreadCpuTimeEnabled(origEnabledFlag);
// round to the nearest power of two
if (delta-delta/4 > milliseconds)
N /= 2;
return N;
}
N *= 2;
lastDelta = delta;
}
}
}
28 changes: 28 additions & 0 deletions src/test/java/com/lambdaworks/crypto/test/SCryptUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
package com.lambdaworks.crypto.test;

import com.lambdaworks.codec.Base64;
import com.lambdaworks.crypto.SCrypt;
import com.lambdaworks.crypto.SCryptUtil;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.security.GeneralSecurityException;
import java.util.Random;
import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.*;
Expand Down Expand Up @@ -56,4 +61,27 @@ public void format_0_rp_max() throws Exception {
assertEquals(r, params >> 8 & 0xff);
assertEquals(p, params >> 0 & 0xff);
}

@Test
public void testTimedIterations() throws GeneralSecurityException {
byte[] salt = "1234".getBytes();
int dkLen = 32;

ThreadMXBean threadBean = ManagementFactory.getThreadMXBean();
boolean cpuTimeSupported = threadBean.isCurrentThreadCpuTimeSupported();
Random random = new Random();
for (int i=0; i<5; i++) {
int targetDuration = 100 + random.nextInt(900);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is your target duration randomly choosen ?
From what I've been taught this is not really a good practice in unit testing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything in particular you're thinking of?
I'd say the benefit is that it covers a wider range of inputs. The downside is that test failures are harder to reproduce.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's exactly what I mean.
Beside you're not tracking the currently choosen value.

One thing you could do is running the two tests : one with a fixed value and another randomly choosen.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have a point about the value not being logged. I'll add that.

int numIterations = SCryptUtil.timedIterations(targetDuration, 8, 1);
long startTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime();
SCrypt.scrypt(passwd.getBytes(), salt, numIterations, 8, 1, dkLen);
long endTime = cpuTimeSupported ? threadBean.getCurrentThreadUserTime() : System.nanoTime();
long actualDuration = (endTime-startTime) / 1000000;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why you take a time in nano an divide it to a lower resolution ?
(femto if I'm correct)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above.


// check that actual duration is within targetDuration - 50% and targetDuration + 60%
String failMessage = "Target duration=" + targetDuration + ", actual=" + actualDuration;
assertTrue(failMessage, actualDuration > targetDuration*0.5);
assertTrue(failMessage, actualDuration < targetDuration*1.6);
}
}
}