diff --git a/crypto-core/src/main/java/com/palantir/crypto2/cipher/Jdk8292158.java b/crypto-core/src/main/java/com/palantir/crypto2/cipher/Jdk8292158.java index fd5cd0e97..fbd2a2013 100644 --- a/crypto-core/src/main/java/com/palantir/crypto2/cipher/Jdk8292158.java +++ b/crypto-core/src/main/java/com/palantir/crypto2/cipher/Jdk8292158.java @@ -21,7 +21,10 @@ import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedSet; +import com.google.common.io.BaseEncoding; +import com.palantir.logsafe.Preconditions; import com.palantir.logsafe.SafeArg; +import com.palantir.logsafe.UnsafeArg; import com.palantir.logsafe.exceptions.SafeIllegalStateException; import com.palantir.logsafe.logger.SafeLogger; import com.palantir.logsafe.logger.SafeLoggerFactory; @@ -31,14 +34,21 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.security.GeneralSecurityException; +import java.security.NoSuchProviderException; import java.util.Arrays; import java.util.Comparator; import java.util.Objects; +import java.util.Random; import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; import java.util.function.BooleanSupplier; import java.util.function.Supplier; import java.util.stream.Stream; import javax.annotation.Nullable; +import javax.crypto.Cipher; +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.SecretKeySpec; /** * Determine if JVM is impacted by https://bugs.openjdk.org/browse/JDK-8292158 which can corrupt AES-CTR encryption @@ -100,6 +110,9 @@ static boolean isAffectedByJdkAesCtrCorruption(Version version, String architect @SuppressWarnings("checkstyle:CyclomaticComplexity") static boolean isAffectedByJdkAesCtrCorruption( Version version, String architecture, Info info, BooleanSupplier cpuHasAvx512) { + if (isAesCtrBroken()) { + return true; + } int featureVersion = version.feature(); if (featureVersion >= 20) { // https://git.openjdk.org/jdk/commit/9d76ac8a4453bc51d9dca2ad6c60259cfb2c4203 in jdk-20+17 @@ -196,4 +209,76 @@ static boolean hasVectorizedAesCpu(Stream lines) { .collect(ImmutableSortedSet.toImmutableSortedSet(Comparator.naturalOrder())); return flags.containsAll(jdk8292158ImpactedCpuFlags); } + + @VisibleForTesting + static boolean isAesCtrBroken() { + try { + for (int i = 8; i <= 32; i++) { + testEncryptDecrypt(i); + } + return false; + } catch (NoSuchProviderException e) { + log.warn("AES-CTR test failed due to no such provider", e); + return false; + } catch (GeneralSecurityException | Error | RuntimeException e) { + log.error("AES-CTR AES-CTR encryption/decryption round-trip failed", e); + return true; + } + } + + static void testEncryptDecrypt(int length) throws GeneralSecurityException { + Preconditions.checkArgument(length > 4, "length must be at least 4"); + + long seed = ThreadLocalRandom.current().nextLong(); + if (log.isDebugEnabled()) { + log.debug( + "Testing AES-CTR encryption/decryption for JDK-829158", + SafeArg.of("seed", seed), + SafeArg.of("length", length)); + } + + Random random = new Random(seed); + + byte[] key = new byte[32]; + random.nextBytes(key); + SecretKeySpec secretKeySpec = new SecretKeySpec(key, "AES"); + + byte[] iv = new byte[16]; + random.nextBytes(iv); + IvParameterSpec ivParameterSpec = new IvParameterSpec(iv); + + Cipher encrypt = Cipher.getInstance("AES/CTR/NoPadding"); + encrypt.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec); + + Cipher decrypt = Cipher.getInstance("AES/CTR/NoPadding"); + decrypt.init(Cipher.DECRYPT_MODE, secretKeySpec, ivParameterSpec); + + byte[] cleartext = new byte[length]; + byte[] encrypted = new byte[length]; + byte[] decrypted = new byte[length]; + + for (int i = 0; i < 10_000; i++) { + random.nextBytes(cleartext); + encrypt.doFinal(cleartext, 0, length, encrypted); + + // use decrypt cipher at least 3 times + decrypt.update(encrypted, 0, 1, decrypted, 0); + decrypt.update(encrypted, 1, 1, decrypted, 1); + decrypt.doFinal(encrypted, 2, length - 2, decrypted, 2); + + if (!Arrays.equals(cleartext, decrypted)) { + throw new SafeIllegalStateException( + "AES-CTR encryption/decryption round trip failed", + cannotEncryptAesCtrSafely(), + SafeArg.of("seed", seed), + SafeArg.of("length", length), + SafeArg.of("iteration", i), + UnsafeArg.of("cleartext", BaseEncoding.base16().encode(cleartext)), + UnsafeArg.of("decrypted", BaseEncoding.base16().encode(decrypted)), + UnsafeArg.of("encrypted", BaseEncoding.base16().encode(encrypted)), + UnsafeArg.of("key", BaseEncoding.base16().encode(key)), + UnsafeArg.of("iv", BaseEncoding.base16().encode(iv))); + } + } + } } diff --git a/crypto-core/src/test/java/com/palantir/crypto2/cipher/Jdk8292158Test.java b/crypto-core/src/test/java/com/palantir/crypto2/cipher/Jdk8292158Test.java index ac48e3234..f2f3c965f 100644 --- a/crypto-core/src/test/java/com/palantir/crypto2/cipher/Jdk8292158Test.java +++ b/crypto-core/src/test/java/com/palantir/crypto2/cipher/Jdk8292158Test.java @@ -19,6 +19,8 @@ import static com.palantir.logsafe.testing.Assertions.assertThatLoggableExceptionThrownBy; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assumptions.assumeThatThrownBy; +import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -50,6 +52,21 @@ void aesCbcIsNotAffected() { .isFalse(); } + @Test + void throwsWhenAffected() { + assumeTrue(Jdk8292158.isAesCtrBroken()); + assumeThatThrownBy(() -> Jdk8292158.isAffectedByJdkAesCtrCorruption(AesCtrCipher.ALGORITHM)) + .isInstanceOf(SafeIllegalStateException.class) + .hasMessageContaining("JVM and CPU architecture is affected by JDK-8292158"); + } + + @Test + void doesNotThrowWhenNotAffected() { + assumeFalse(Jdk8292158.isAesCtrBroken()); + assertThat(Jdk8292158.isAffectedByJdkAesCtrCorruption(AesCtrCipher.ALGORITHM)) + .isFalse(); + } + @Test void aesCtrMayBeAffected() { assumeThatThrownBy(() -> assertThat(Jdk8292158.isAffectedByJdkAesCtrCorruption(AesCtrCipher.ALGORITHM))