Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into merge-upstream-again
Browse files Browse the repository at this point in the history
  • Loading branch information
julienrf committed Aug 27, 2024
2 parents 3a58ec8 + b4edd13 commit 6d41de7
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
Expand Down Expand Up @@ -462,13 +462,7 @@ protected AwsCredentialsProvider getAwsCredentialsProvider(Configuration conf) {
// initialized
String providerClass = conf.get(DynamoDBConstants.CUSTOM_CREDENTIALS_PROVIDER_CONF);
if (!Strings.isNullOrEmpty(providerClass)) {
try {
providersList.add(
(AwsCredentialsProvider) ReflectionUtils.newInstance(Class.forName(providerClass), conf)
);
} catch (ClassNotFoundException e) {
throw new RuntimeException("Custom AWSCredentialsProvider not found: " + providerClass, e);
}
providersList.add(DynamoDBUtil.loadAwsCredentialsProvider(providerClass, conf));
}

// try to fetch credentials from core-site
Expand All @@ -485,7 +479,8 @@ protected AwsCredentialsProvider getAwsCredentialsProvider(Configuration conf) {
}

if (Strings.isNullOrEmpty(accessKey) || Strings.isNullOrEmpty(secretKey)) {
providersList.add(InstanceProfileCredentialsProvider.create());
log.debug("Custom credential provider not found, loading default provider from sdk");
providersList.add(DefaultCredentialsProvider.create());
} else if (!Strings.isNullOrEmpty(sessionKey)) {
final AwsCredentials credentials =
AwsSessionCredentials.create(accessKey, secretKey, sessionKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.dynamodb.util.ClusterTopologyNodeCapacityProvider;
import org.apache.hadoop.dynamodb.util.DynamoDBReflectionUtils;
import org.apache.hadoop.dynamodb.util.NodeCapacityProvider;
import org.apache.hadoop.dynamodb.util.RoundRobinYarnContainerAllocator;
import org.apache.hadoop.dynamodb.util.TaskCalculator;
Expand All @@ -49,6 +50,7 @@
import org.apache.hadoop.mapred.JobConf;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.internal.util.EC2MetadataUtils;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
Expand Down Expand Up @@ -279,6 +281,32 @@ public static String getDynamoDBRegion(Configuration conf, String region) {
return DynamoDBConstants.DEFAULT_AWS_REGION;
}

/**
*
* Utility method to load an aws credentials provider from config via reflection. There are two
* strategies followed:
* 1. Load credential provider via its 'create' method.
* This is the intended credential provider construction mechanism with aws-java-sdk-v2
* For more information, visit {@link <a href="https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/migration-client-credentials.html">Credential Provider Changes</a>}.
* 2. If 'create' method is not found, fallback to default no-arg constructor.
* This is kept to ensure utility method maintains backwards compatibility with what it
* used to support.
*
* @param providerClass - class name loaded from conf used as custom credential provider
* @return - credential provider loaded via reflection using class name from conf
*/
public static AwsCredentialsProvider loadAwsCredentialsProvider(
String providerClass,
Configuration conf) {
if (DynamoDBReflectionUtils.hasFactoryMethod(providerClass, "create")) {
log.debug("Provider: " + providerClass + " contains required method for creation - create()");
return DynamoDBReflectionUtils.createInstanceFromFactory(providerClass, conf, "create");
} else {
log.debug("Falling back to default constructor.");
return DynamoDBReflectionUtils.createInstanceOf(providerClass, conf);
}
}

public static JobClient createJobClient(JobConf jobConf) {
try {
return new JobClient(jobConf);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/**
* Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
* except in compliance with the License. A copy of the License is located at
*
*     http://aws.amazon.com/apache2.0/
*
* or in the "LICENSE.TXT" file accompanying this file. This file is distributed on an "AS IS"
* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under the License.
*/

package org.apache.hadoop.dynamodb.util;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Optional;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.ReflectionUtils;

/**
* Reflected-utility methods for DynamoDB connector pkg
*/
public class DynamoDBReflectionUtils {

private static final Log log = LogFactory.getLog(DynamoDBReflectionUtils.class);

// Default no-arg constructor reflection logic
public static <T> T createInstanceOf(String className, Configuration conf) {
return createInstanceOfWithParams(className, conf, null, null);
}

// constructor with-args reflection logic
@SuppressWarnings("unchecked")
public static <T> T createInstanceOfWithParams(
String className,
Configuration conf,
Class<?>[] paramTypes,
Object[] params) {
try {
Class<?> clazz = getClass(className);
Constructor<T> ctor = paramTypes == null
? (Constructor<T>) clazz.getDeclaredConstructor()
: (Constructor<T>) clazz.getDeclaredConstructor(paramTypes);
ctor.setAccessible(true);
T instance = ctor.newInstance(params);
log.info("Successfully loaded class: " + className);
ReflectionUtils.setConf(instance, conf);
log.debug("Configured instance to use conf");
return instance;

} catch (NoSuchMethodException | InvocationTargetException e) {
throw new RuntimeException("Unable to find constructor of class: " + className, e);
} catch (IllegalAccessException e) {
throw new RuntimeException("Class being loaded is not accessible: " + className, e);
} catch (InstantiationException e) {
throw new RuntimeException("Unable to instantiate class: " + className, e);
}
}

// checks if class has method available in it
public static boolean hasFactoryMethod(String className, String methodName) {
Class<?> clazz = getClass(className);
return Arrays.stream(clazz.getMethods())
.anyMatch(method -> method.getName().equals(methodName));
}

// factory-based reflection logic that uses a method for object construction
@SuppressWarnings("unchecked")
public static <T> T createInstanceFromFactory(
String className,
Configuration conf,
String methodName) {
try {
Class<?> clazz = getClass(className);
Method m = clazz.getDeclaredMethod(methodName);
m.setAccessible(true);
T instance = (T) m.invoke(null);
log.info("Successfully loaded class: " + className);
ReflectionUtils.setConf(instance, conf);
log.debug("Configured instance to use conf");
return instance;

} catch (NoSuchMethodException e) {
log.error("Method not found for object construction: " + methodName);
throw new RuntimeException("Unable to find static method to load class: " + className, e);
} catch (InvocationTargetException e) {
log.error("Exception found when invoking method for object construction: " + methodName);
throw new RuntimeException("Unable to load class: " + className, e);
} catch (IllegalAccessException e) {
throw new RuntimeException("Class being loaded is not accessible: " + className, e);
}
}

// checks if class can be loaded
private static Class<?> getClass(String className) {
try {
return Class.forName(className, true, getContextOrDefaultClassLoader());
} catch (ClassNotFoundException e) {
throw new RuntimeException("Unable to locate class to load via reflection: " + className, e);
}
}

private static ClassLoader getContextOrDefaultClassLoader() {
return Optional.of(Thread.currentThread().getContextClassLoader())
.orElseGet(DynamoDBReflectionUtils::getDefaultClassLoader);
}

private static ClassLoader getDefaultClassLoader() {
return DynamoDBReflectionUtils.class.getClassLoader();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import static org.mockito.Mockito.mock;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import org.apache.hadoop.conf.Configurable;
Expand All @@ -28,7 +27,6 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.Mockito;

import java.net.URI;
Expand All @@ -37,11 +35,16 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.lang.reflect.Field;
import java.util.List;
import java.util.Map;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.http.apache.ProxyConfiguration;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClientBuilder;
Expand Down Expand Up @@ -101,7 +104,7 @@ public void testDefaultCredentials() {
}

@Test
public void testCustomCredentialsProvider() {
public void testCustomCredentialsProviderWithConstructor() {
final String MY_ACCESS_KEY = "abc";
final String MY_SECRET_KEY = "xyz";
Configuration conf = new Configuration();
Expand All @@ -116,6 +119,22 @@ public void testCustomCredentialsProvider() {
Assert.assertEquals(MY_SECRET_KEY, provider.resolveCredentials().secretAccessKey());
}

@Test
public void testCustomCredentialsProviderWithMethod() {
final String MY_ACCESS_KEY = "abc";
final String MY_SECRET_KEY = "xyz";
Configuration conf = new Configuration();
conf.set("my.accessKey", MY_ACCESS_KEY);
conf.set("my.secretKey", MY_SECRET_KEY);
conf.set(DynamoDBConstants.CUSTOM_CREDENTIALS_PROVIDER_CONF, MyFactoryCredentialsProvider.class
.getName());

DynamoDBClient dynamoDBClient = new DynamoDBClient();
AwsCredentialsProvider provider = dynamoDBClient.getAwsCredentialsProvider(conf);
Assert.assertEquals(MY_ACCESS_KEY, provider.resolveCredentials().accessKeyId());
Assert.assertEquals(MY_SECRET_KEY, provider.resolveCredentials().secretAccessKey());
}

@Test
public void testCustomProviderNotFound() {
Configuration conf = new Configuration();
Expand Down Expand Up @@ -176,6 +195,24 @@ public void testBasicSessionCredentials(){

}

@Test
public void testDefaultCredentialProvider() {
DynamoDBClient dynamoDBClient = new DynamoDBClient();
AwsCredentialsProvider provider = dynamoDBClient.getAwsCredentialsProvider(conf);
Assert.assertTrue(provider instanceof AwsCredentialsProviderChain);
AwsCredentialsProviderChain providerChain = (AwsCredentialsProviderChain) provider;
try {
Field providersField = AwsCredentialsProviderChain.class.getDeclaredField("credentialsProviders");
providersField.setAccessible(true);
@SuppressWarnings("unchecked")
List<AwsCredentialsProvider> providers = (List<AwsCredentialsProvider>) providersField.get(providerChain);
Assert.assertEquals(1, providers.size());
Assert.assertTrue(providers.get(0) instanceof DefaultCredentialsProvider);
} catch (Exception e) {
Assert.fail("Unexpected error thrown: " + e.getMessage());
}
}

@Test
public void setsClientConfigurationProxyHostAndPortWhenBothAreSupplied() {
setTestProxyHostAndPort(conf);
Expand Down Expand Up @@ -351,6 +388,7 @@ private void setProxyUsernameAndPassword(Configuration conf, String username, St
}
}

// Default Constructor-based credential provider
private static class MyAWSCredentialsProvider implements AwsCredentialsProvider, Configurable {
private Configuration conf;
private String accessKey;
Expand Down Expand Up @@ -378,6 +416,34 @@ public void setConf(Configuration configuration) {
}
}

// Method-based constructor credential provider
private static class MyFactoryCredentialsProvider implements AwsCredentialsProvider, Configurable {
private Configuration conf;
private String accessKey;
private String secretKey;

public static MyFactoryCredentialsProvider create() {
return new MyFactoryCredentialsProvider();
}

@Override
public AwsCredentials resolveCredentials() {
return AwsBasicCredentials.create(accessKey, secretKey);
}

@Override
public Configuration getConf() {
return this.conf;
}

@Override
public void setConf(Configuration configuration) {
this.conf = configuration;
accessKey = conf.get("my.accessKey");
secretKey = conf.get("my.secretKey");
}
}

private static class MyDynamoDbClientBuilderTransformer implements DynamoDbClientBuilderTransformer, Configurable {

private Configuration conf;
Expand Down

0 comments on commit 6d41de7

Please sign in to comment.