in your home directory (instead of stdout)."
+ " This will overwrite any existing file with the same name";
- private static final String SSH_TUNNEL_SERVICE_OPTION_DESCRIPTION =
- "Starts an SSH Tunnel service.";
// Messages string constants
public static final String DUPLICATE_COLUMN_KEY_DETECTED_FOR_TABLE_SCHEMA =
"Duplicate column key '%s' detected for table schema '%s'. Original column '%s'."
@@ -247,7 +241,6 @@ public class DocumentDbMain {
LIBRARY_NAME = getLibraryName();
HELP_OPTION = buildHelpOption();
VERSION_OPTION = buildVersionOption();
- SSH_TUNNEL_SERVICE_OPTIONS = buildSshTunnelServiceOption();
COMMAND_OPTIONS = buildCommandOptions();
REQUIRED_OPTIONS = buildRequiredOptions();
OPTIONAL_OPTIONS = buildOptionalOptions();
@@ -348,13 +341,6 @@ static void handleCommandLine(final String[] args, final StringBuilder output)
}
try {
final CommandLineParser parser = new DefaultParser();
- // First check for the SSH tunnel service option separately from the other options.
- final CommandLine commandLineSshTunnelService = parser.parse(SSH_TUNNEL_SERVICE_OPTIONS, args, true);
- if (commandLineSshTunnelService.hasOption(SSH_TUNNEL_SERVICE_OPTION_NAME)) {
- performSshTunnelService(commandLineSshTunnelService, output);
- return;
- }
- // Otherwise, consider the "complete" options for metadata options.
final CommandLine commandLine = parser.parse(COMPLETE_OPTIONS, args);
final DocumentDbConnectionProperties properties = new DocumentDbConnectionProperties();
if (!tryGetConnectionProperties(commandLine, properties, output)) {
@@ -420,30 +406,6 @@ private static void closeClient() {
}
}
- private static void performSshTunnelService(
- final CommandLine commandLine,
- final StringBuilder output) throws DuplicateKeyException {
- try (DocumentDbSshTunnelService service = new DocumentDbSshTunnelService(
- commandLine.getOptionValue(SSH_TUNNEL_SERVICE_OPTION_NAME))) {
- final Thread serviceThread = new Thread(service);
- serviceThread.setDaemon(true);
- serviceThread.start();
- do {
- serviceThread.join(1000);
- } while (serviceThread.isAlive());
- service.getExceptions().forEach(
- e -> output
- .append(e.getMessage())
- .append(System.lineSeparator())
- .append(Arrays.stream(e.getStackTrace())
- .map(StackTraceElement::toString)
- .collect(Collectors.joining(System.lineSeparator())))
- .append(System.lineSeparator()));
- } catch (Exception e) {
- output.append(e.getMessage());
- }
- }
-
private static void performImport(
final CommandLine commandLine,
final DocumentDbConnectionProperties properties,
@@ -1034,16 +996,6 @@ private static Option buildVersionOption() {
.build();
}
- private static Options buildSshTunnelServiceOption() {
- return new Options().addOption(
- Option.builder()
- .longOpt(SSH_TUNNEL_SERVICE_OPTION_NAME)
- .desc(SSH_TUNNEL_SERVICE_OPTION_DESCRIPTION)
- .numberOfArgs(1)
- .argName(SSH_TUNNEL_SERVICE_ARG_NAME)
- .build());
- }
-
private static Option buildHelpOption() {
return Option.builder(HELP_OPTION_FLAG)
.longOpt(HELP_OPTION_NAME)
diff --git a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbMultiThreadFileChannel.java b/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbMultiThreadFileChannel.java
deleted file mode 100644
index dcd62502..00000000
--- a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbMultiThreadFileChannel.java
+++ /dev/null
@@ -1,143 +0,0 @@
-/*
- * Copyright <2022> Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License").
- * You may not use this file except in compliance with the License.
- * A copy of the License is located at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * or in the "license" file accompanying this file. This file is distributed
- * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied. See the License for the specific language governing
- * permissions and limitations under the License.
- *
- */
-
-package software.amazon.documentdb.jdbc.sshtunnel;
-
-import org.checkerframework.checker.nullness.qual.NonNull;
-
-import java.io.IOException;
-import java.nio.channels.FileChannel;
-import java.nio.channels.FileLock;
-import java.nio.channels.NonWritableChannelException;
-import java.nio.channels.OverlappingFileLockException;
-import java.nio.file.OpenOption;
-import java.nio.file.Path;
-import java.util.concurrent.TimeUnit;
-
-/**
- * An implementation of a {@link FileChannel}-like class to handle locking of
- * processes and threads.
- */
-final class DocumentDbMultiThreadFileChannel implements AutoCloseable {
- private final FileChannel fileChannel;
-
- /**
- * Constructs a new instance of {@link DocumentDbMultiThreadFileChannel}.
- *
- * @param fileChannel the underlying {@link FileChannel} to use.
- */
- private DocumentDbMultiThreadFileChannel(final @NonNull FileChannel fileChannel) {
- this.fileChannel = fileChannel;
- }
-
- @Override
- public void close() throws Exception {
- fileChannel.close();
- }
-
- /**
- * Opens or creates a file, returning a file channel to access the file
- * where options is a set of the options specified in the options array.
- *
- * @param path The path of the file to open or create
- * @param options Options specifying how the file is opened
- * @return A new file channel
- * @throws IOException If an I/O error occurs
- */
- public static DocumentDbMultiThreadFileChannel open(final Path path, final OpenOption... options)
- throws IOException {
- final FileChannel fileChannel = FileChannel.open(path, options);
- return new DocumentDbMultiThreadFileChannel(fileChannel);
- }
-
- /**
- * Attempts to acquire an exclusive lock on this channel's file.
- * An invocation of this method of the form fc.tryLock() behaves in exactly the same way as the invocation
- *
- * fc.tryLock(0L, Long.MAX_VALUE, false)
- *
- * Additionally, this implementation treats the {@link OverlappingFileLockException} the same as a
- * concurrent lock on the file. In this case, this method automatically returns null indicated the file is locked.
- *
- * @return A lock object representing the newly-acquired lock, or null if the lock could not be acquired because
- * another program holds an overlapping lock
- * @throws IOException If some other I/O error occurs
- */
- public FileLock tryLock() throws IOException {
- try {
- return fileChannel.tryLock();
- } catch (OverlappingFileLockException | NonWritableChannelException e) {
- return null;
- }
- }
-
- /**
- * Acquires an exclusive lock on this channel's file.
- * An invocation of this method of the form fc.lock() behaves in exactly the same way as the invocation
- *
- * fc.lock(0L, Long.MAX_VALUE, false)
- *
- * Additionally, this implementation treats the {@link OverlappingFileLockException} the same as a
- * concurrent lock on the file. It automatically retries until the lock is obtained.
- *
- * @return a {@link FileLock} object.
- * @throws IOException if the file lock fails.
- * @throws InterruptedException if the thread is interrupted while sleeping.
- */
- @NonNull
- public FileLock lock() throws IOException, InterruptedException {
- return lock(10);
- }
-
- /**
- * Acquires an exclusive lock on this channel's file.
- * An invocation of this method of the form fc.lock() behaves in exactly the same way as the invocation
- *
- * fc.lock(0L, Long.MAX_VALUE, false)
- *
- * Additionally, this implementation treats the {@link OverlappingFileLockException} the same as a
- * concurrent lock on the file. It automatically retries until the lock is obtained.
- *
- * @param pollTimeMS the amount of time, in milliseconds, to sleep between retries in the case an
- * {@link OverlappingFileLockException} is detected.
- *
- * @return A lock object representing the newly-acquired lock
- * @throws IOException If the file lock fails.
- * @throws InterruptedException if the thread is interrupted while sleeping.
- */
- @NonNull
- public FileLock lock(final int pollTimeMS) throws IOException, InterruptedException {
- FileLock fileLock;
- do {
- try {
- fileLock = fileChannel.lock();
- } catch (OverlappingFileLockException | NonWritableChannelException e) {
- // This is meant to handle multiple threads locking a single file.
- TimeUnit.MILLISECONDS.sleep(pollTimeMS);
- fileLock = null;
- }
- } while (fileLock == null);
- return fileLock;
- }
-
- public boolean isOpen() {
- return fileChannel.isOpen();
- }
-
- FileChannel getFileChannel() {
- return fileChannel;
- }
-}
diff --git a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelClient.java b/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelClient.java
index 7e5d5771..648da949 100644
--- a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelClient.java
+++ b/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelClient.java
@@ -18,17 +18,10 @@
import com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.NonNull;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
import software.amazon.documentdb.jdbc.DocumentDbConnectionProperties;
-import java.io.IOException;
-import java.nio.channels.FileLock;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.StandardOpenOption;
import java.sql.SQLException;
-import java.util.UUID;
+import java.util.concurrent.atomic.AtomicBoolean;
import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.isNullOrWhitespace;
@@ -37,47 +30,30 @@
* a single instance SSH tunnel is started and stays running while this object is alive.
*/
public class DocumentDbSshTunnelClient implements AutoCloseable {
- private static final Logger LOGGER = LoggerFactory.getLogger(DocumentDbSshTunnelClient.class);
- private final Object mutex = new Object();
-
- private final String propertiesHashString;
private final DocumentDbSshTunnelServer sshTunnelServer;
-
- private volatile FileLock clientLock = null;
- private volatile DocumentDbMultiThreadFileChannel clientChannel = null;
- private volatile Path clientLockPath = null;
+ private final AtomicBoolean closed;
+ private final Object lock = new Object();
/**
* Creates a new SSH Tunnel client object from the given connection properties.
*
* @param properties The connection properties for this SSH Tunnel.
- * @throws Exception When an error occurs attempting to ensure an SSH Tunnel instance is running.
+ * @throws SQLException When an error occurs attempting to ensure an SSH Tunnel instance is running.
*/
public DocumentDbSshTunnelClient(final @NonNull DocumentDbConnectionProperties properties)
- throws Exception {
+ throws SQLException {
validateSshTunnelProperties(properties);
- this.propertiesHashString = DocumentDbSshTunnelLock.getHashString(
- properties.getSshUser(),
- properties.getSshHostname(),
- properties.getSshPrivateKeyFile(),
- properties.getHostname());
-
- try {
- ensureClientLocked();
- sshTunnelServer = DocumentDbSshTunnelServer.builder(
- properties.getSshUser(),
- properties.getSshHostname(),
- properties.getSshPrivateKeyFile(),
- properties.getHostname())
- .sshPrivateKeyPassphrase(properties.getSshPrivateKeyPassphrase())
- .sshStrictHostKeyChecking(properties.getSshStrictHostKeyChecking())
- .sshKnownHostsFile(properties.getSshKnownHostsFile())
- .build();
- sshTunnelServer.addClient();
- } catch (Exception e) {
- ensureClientUnlocked();
- throw e;
- }
+ sshTunnelServer = DocumentDbSshTunnelServer.builder(
+ properties.getSshUser(),
+ properties.getSshHostname(),
+ properties.getSshPrivateKeyFile(),
+ properties.getHostname())
+ .sshPrivateKeyPassphrase(properties.getSshPrivateKeyPassphrase())
+ .sshStrictHostKeyChecking(properties.getSshStrictHostKeyChecking())
+ .sshKnownHostsFile(properties.getSshKnownHostsFile())
+ .build();
+ sshTunnelServer.addClient();
+ closed = new AtomicBoolean(false);
}
private static void validateSshTunnelProperties(final DocumentDbConnectionProperties properties)
@@ -105,9 +81,8 @@ public int getServiceListeningPort() {
* Gets indicator of whether the SSH Tunnel server is alive.
*
* @return Returns true if the server is alive, false otherwise.
- * @throws Exception When an error occurs trying to determine if the server is alive.
*/
- public boolean isServerAlive() throws Exception {
+ public boolean isServerAlive() {
return getSshTunnelServer().isAlive();
}
@@ -115,13 +90,16 @@ public boolean isServerAlive() throws Exception {
* Closes the client object by unlocking and deleting the client lock file. If this is the last client
* for the server, the SSH Tunnel server will be shutdown.
*
- * @throws Exception When an error occurs unlocking the client lock file or shutting down the server.
+ * @throws SQLException When an error occurs closing the session.
*/
@Override
- public void close() throws Exception {
- synchronized (mutex) {
- ensureClientUnlocked();
+ public void close() throws SQLException {
+ synchronized (lock) {
+ if (closed.get()) {
+ return;
+ }
sshTunnelServer.removeClient();
+ closed.set(true);
}
}
@@ -134,88 +112,4 @@ public void close() throws Exception {
@NonNull DocumentDbSshTunnelServer getSshTunnelServer() {
return sshTunnelServer;
}
-
- /**
- * Ensures the client lock file is created and locked.
- *
- * @throws Exception When an error occurs trying to create and lock the client lock file.
- */
- private void ensureClientLocked() throws Exception {
- initializeClientLockFolder();
- final Exception exception = DocumentDbSshTunnelLock.runInGlobalLock(propertiesHashString, this::lockClientFile);
- if (exception != null) {
- throw exception;
- }
- }
-
- /**
- * Initializes the client by ensuring the parent folder exists, gets a UUID for this client lock file.
- *
- * @throws IOException When an error occurs trying to create the parent directories.
- */
- private void initializeClientLockFolder() throws IOException {
- final UUID unique = UUID.randomUUID();
- clientLockPath = DocumentDbSshTunnelLock.getClientLockPath(unique, propertiesHashString);
- final Path parentPath = clientLockPath.getParent();
- assert parentPath != null;
- Files.createDirectories(parentPath);
- }
-
- /**
- * Locks the client lock file. Assumes it is run inside the global lock context.
- *
- * @return An Exception if an error occurs locking the client lock file, null otherwise.
- */
- private Exception lockClientFile() {
- Exception e = null;
- try {
- clientChannel = DocumentDbMultiThreadFileChannel.open(
- clientLockPath, StandardOpenOption.CREATE_NEW, StandardOpenOption.WRITE);
- clientLock = clientChannel.lock();
- LOGGER.debug("SSH Tunnel server client lock active.");
- } catch (Exception ex) {
- e = ex;
- }
- return e;
- }
-
- /**
- * Ensures the client is unlocked. Is run in the global lock context.
- *
- * @throws Exception When an error occurs unlocking the client lock file.
- */
- private void ensureClientUnlocked() throws Exception {
- final Exception exception = DocumentDbSshTunnelLock.runInGlobalLock(
- propertiesHashString, this::unlockClientFile);
- if (exception != null) {
- throw exception;
- }
- }
-
- /**
- * Unlocks the client lock file and cleans up the file from the folder.
- *
- * @return An exception if an error occurs unlocking the click lock file, null otherwise.
- */
- private Exception unlockClientFile() {
- Exception exception = null;
- try {
- if (clientLock != null && clientLock.isValid()) {
- clientLock.close();
- LOGGER.debug("SSH Tunnel server client lock inactive.");
- }
- if (clientChannel != null && clientChannel.isOpen()) {
- clientChannel.close();
- }
- if (clientLockPath != null) {
- Files.deleteIfExists(clientLockPath);
- }
- } catch (Exception e) {
- exception = e;
- }
- clientLock = null;
- clientChannel = null;
- clientLockPath = null;
- return exception;
- }
}
diff --git a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelLock.java b/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelLock.java
deleted file mode 100644
index b86dcfc1..00000000
--- a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelLock.java
+++ /dev/null
@@ -1,207 +0,0 @@
-/*
- * Copyright <2022> Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License").
- * You may not use this file except in compliance with the License.
- * A copy of the License is located at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * or in the "license" file accompanying this file. This file is distributed
- * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied. See the License for the specific language governing
- * permissions and limitations under the License.
- *
- */
-
-package software.amazon.documentdb.jdbc.sshtunnel;
-
-import com.google.common.hash.Hashing;
-import org.checkerframework.checker.nullness.qual.NonNull;
-import software.amazon.documentdb.jdbc.DocumentDbConnectionProperties;
-
-import java.io.File;
-import java.io.IOException;
-import java.nio.channels.FileLock;
-import java.nio.charset.StandardCharsets;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.StandardOpenOption;
-import java.util.Comparator;
-import java.util.UUID;
-import java.util.function.Supplier;
-import java.util.stream.Stream;
-
-/**
- * The DocumentDbSshTunnelLock provides various static methods to support
- * file locking for the SSH tunnel implementation.
- */
-public final class DocumentDbSshTunnelLock {
- static final String LOCK_BASE_FOLDER_NAME = getLockBaseFolderName();
- static final String PORT_LOCK_NAME = ".sshTunnelLockPort";
- static final String STARTUP_LOCK_NAME = ".sshTunnelLockStartup";
- static final String SERVER_LOCK_NAME = ".sshTunnelLockServer";
- static final String CLIENT_LOCK_FOLDER_NAME = "clients";
- private static final String GLOBAL_LOCK_NAME = ".sshTunnelLockGlobal";
- private static final String CLIENT_LOCK_NAME = ".sshTunnelLockClient";
- private static String classPathLocationName = null;
-
- private DocumentDbSshTunnelLock() {
- // Empty by design
- }
-
- /**
- * Gets the hash string for the SSH properties provided.
- *
- * @param sshUser the username credential for the SSH tunnel.
- * @param sshHostname the hostname (or IP address) for the SSH tunnel.
- * @param sshPrivateKeyFile the path to the private key file.
- *
- * @return a String value representing the hash of the given properties.
- */
- static String getHashString(
- final String sshUser,
- final String sshHostname,
- final String sshPrivateKeyFile,
- final String remoteHostname) {
- final String sshPropertiesString = sshUser + "-" + sshHostname + "-" + sshPrivateKeyFile + remoteHostname;
- return Hashing.sha256()
- .hashString(sshPropertiesString, StandardCharsets.UTF_8)
- .toString();
- }
-
- /**
- * Runs a {@link Supplier} method within a "global lock".
- *
- * @param propertiesHashString the SSH properties hash string.
- * @param function the lambda function to execute.
- * @return the value returned from the lambda function.
- * @param the return value type.
- * @throws Exception thrown if an error occurs when attaining the "global lock".
- */
- static R runInGlobalLock(final @NonNull String propertiesHashString, final @NonNull Supplier function)
- throws Exception {
- final Path globalLockPath = getGlobalLockPath(propertiesHashString);
- final Path parentPath = globalLockPath.getParent();
- assert parentPath != null;
- Files.createDirectories(parentPath);
- try (DocumentDbMultiThreadFileChannel globalChannel = DocumentDbMultiThreadFileChannel.open(
- globalLockPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE);
- FileLock ignored = globalChannel.lock()) {
- return function.get();
- } // Note: this releases the lock, too.
- }
-
- static @NonNull Path getGlobalLockPath(final @NonNull String propertiesHashString) {
- return Paths.get(
- LOCK_BASE_FOLDER_NAME,
- propertiesHashString,
- GLOBAL_LOCK_NAME);
- }
-
- static @NonNull Path getServerLockPath(final @NonNull String propertiesHashString) {
- return Paths.get(
- LOCK_BASE_FOLDER_NAME,
- propertiesHashString,
- SERVER_LOCK_NAME
- );
- }
-
- static @NonNull Path getStartupLockPath(final @NonNull String sshPropertiesHashString) {
- return Paths.get(
- LOCK_BASE_FOLDER_NAME,
- sshPropertiesHashString,
- STARTUP_LOCK_NAME);
- }
-
- static @NonNull Path getPortLockPath(final @NonNull String propertiesHashString) {
- return Paths.get(
- LOCK_BASE_FOLDER_NAME,
- propertiesHashString,
- PORT_LOCK_NAME
- );
- }
-
- static @NonNull Path getLockDirectoryPath(final @NonNull String propertiesHashString) {
- return Paths.get(
- LOCK_BASE_FOLDER_NAME,
- propertiesHashString
- );
- }
-
- static void deleteStartupAndPortLockFiles(
- final @NonNull Path startupLockPath,
- final @NonNull Path portLockPah) throws IOException {
- Files.deleteIfExists(portLockPah);
- Files.deleteIfExists(startupLockPath);
- }
-
- static void deleteLockDirectory(final @NonNull String propertiesHashString) throws IOException {
- final Path lockDirectoryPath = getLockDirectoryPath(propertiesHashString);
- deleteDirectoryRecursive(lockDirectoryPath);
- }
-
- private static void deleteDirectoryRecursive(final @NonNull Path directoryPath) throws IOException {
- if (!Files.exists(directoryPath)) {
- return;
- }
- try (Stream pathStream = Files.walk(directoryPath)) {
- pathStream.sorted(Comparator.reverseOrder())
- .map(Path::toFile)
- .forEach(File::delete);
-
- }
- }
-
- static @NonNull Path getClientLockPath(final @NonNull UUID unique, final @NonNull String propertiesHashString) {
- return Paths.get(
- LOCK_BASE_FOLDER_NAME,
- propertiesHashString,
- CLIENT_LOCK_FOLDER_NAME,
- CLIENT_LOCK_NAME + "-" + unique);
- }
-
- static Path getClientsFolderPath(final String sshPropertiesHashString) {
- return Paths.get(
- LOCK_BASE_FOLDER_NAME,
- sshPropertiesHashString,
- CLIENT_LOCK_FOLDER_NAME);
- }
-
- private static String getLockBaseFolderName() {
- return Paths.get(
- getDocumentdbHomePathName(),
- "sshTunnelLocks").toString();
- }
-
- /**
- * Gets the ~/.documentdb path name.
- *
- * @return the ~/.documentdb path name.
- */
- public static String getDocumentdbHomePathName() {
- return DocumentDbConnectionProperties.DOCUMENTDB_HOME_PATH_NAME;
- }
-
- /**
- * Gets the class path's location name.
- *
- * @return the class path's location name.
- */
- public static String getClassPathLocationName() {
- if (classPathLocationName == null) {
- classPathLocationName = DocumentDbConnectionProperties.getClassPathLocation();
- }
- return classPathLocationName;
- }
-
- /**
- * Gets the user's home path name.
- *
- * @return the user's home path name.
- */
- static String getUserHomePathName() {
- return DocumentDbConnectionProperties.USER_HOME_PATH_NAME;
- }
-}
diff --git a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServer.java b/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServer.java
index 22e7d7d4..9c14a4c9 100644
--- a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServer.java
+++ b/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServer.java
@@ -16,42 +16,36 @@
package software.amazon.documentdb.jdbc.sshtunnel;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.hash.Hashing;
+import com.jcraft.jsch.HostKey;
+import com.jcraft.jsch.HostKeyRepository;
+import com.jcraft.jsch.JSch;
+import com.jcraft.jsch.JSchException;
+import com.jcraft.jsch.Session;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.documentdb.jdbc.DocumentDbConnectionProperties;
-import software.amazon.documentdb.jdbc.DocumentDbMain;
import software.amazon.documentdb.jdbc.common.utilities.SqlError;
import software.amazon.documentdb.jdbc.common.utilities.SqlState;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.net.URI;
-import java.net.URISyntaxException;
-import java.nio.channels.Channels;
-import java.nio.channels.FileLock;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.StandardOpenOption;
import java.sql.SQLException;
-import java.time.Duration;
-import java.time.Instant;
import java.util.Arrays;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.atomic.AtomicLong;
import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.getPath;
import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.isNullOrWhitespace;
@@ -64,20 +58,20 @@
* then call the build() method.
*/
public final class DocumentDbSshTunnelServer implements AutoCloseable {
+ public static final String SSH_KNOWN_HOSTS_FILE = "~/.ssh/known_hosts";
+ public static final String STRICT_HOST_KEY_CHECKING = "StrictHostKeyChecking";
+ public static final String HASH_KNOWN_HOSTS = "HashKnownHosts";
+ public static final String SERVER_HOST_KEY = "server_host_key";
+ public static final String YES = "yes";
+ public static final String NO = "no";
+ public static final String LOCALHOST = "localhost";
+ public static final int DEFAULT_DOCUMENTDB_PORT = 27017;
+ public static final int DEFAULT_SSH_PORT = 22;
private static final Logger LOGGER = LoggerFactory.getLogger(DocumentDbSshTunnelServer.class);
- private static final int SERVER_WATCHER_POLL_TIME_MS = 500;
- private static final Object MUTEX = new Object();
- private static final String DOCUMENTDB_SSH_TUNNEL_PATH = "DOCUMENTDB_SSH_TUNNEL_PATH";
- private static final String JAVA_HOME = "java.home";
- private static final String JAVA_CLASS_PATH = "java.class.path";
- private static final String CLASS_PATH_OPTION_NAME = "-cp";
- private static final String BIN_FOLDER_NAME = "bin";
- private static final String JAVA_EXECUTABLE_NAME = "java";
- private static final String SSH_TUNNEL_SERVICE_OPTION_NAME = "--" + DocumentDbMain.SSH_TUNNEL_SERVICE_OPTION_NAME;
- public static final int SERVICE_WAIT_TIMEOUT_SECONDS = 120;
- public static final String FILE_SCHEME = "file";
-
- private final AtomicInteger clientCount = new AtomicInteger(0);
+ public static final int DEFAULT_CLOSE_DELAY_MS = 30000;
+
+ private final Object mutex = new Object();
+ private final AtomicLong clientCount = new AtomicLong(0);
private final String sshUser;
private final String sshHostname;
@@ -86,12 +80,11 @@ public final class DocumentDbSshTunnelServer implements AutoCloseable {
private final boolean sshStrictHostKeyChecking;
private final String sshKnownHostsFile;
private final String remoteHostname;
- private final String propertiesHashString;
- private final AtomicBoolean serverAlive = new AtomicBoolean(false);
- private ServerWatcher serverWatcher = null;
- private Thread serverWatcherThread = null;
+ private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
+ private DocumentDbSshTunnelServer.SshPortForwardingSession session = null;
+ private ScheduledFuture> scheduledFuture = null;
+ private long closeDelayMS = DEFAULT_CLOSE_DELAY_MS;
- private volatile int serviceListeningPort = 0;
private DocumentDbSshTunnelServer(final DocumentDbSshTunnelServerBuilder builder) {
this.sshUser = builder.sshUser;
@@ -101,8 +94,6 @@ private DocumentDbSshTunnelServer(final DocumentDbSshTunnelServerBuilder builder
this.sshPrivateKeyPassphrase = builder.sshPrivateKeyPassphrase;
this.sshStrictHostKeyChecking = builder.sshStrictHostKeyChecking;
this.sshKnownHostsFile = builder.sshKnownHostsFile;
- this.propertiesHashString = DocumentDbSshTunnelLock.getHashString(
- sshUser, sshHostname, sshPrivateKeyFile, remoteHostname);
LOGGER.debug("sshUser='{}' sshHostname='{}' sshPrivateKeyFile='{}' remoteHostname'{}"
+ " sshPrivateKeyPassphrase='{}' sshStrictHostKeyChecking='{}' sshKnownHostsFile='{}'",
this.sshUser,
@@ -115,6 +106,175 @@ private DocumentDbSshTunnelServer(final DocumentDbSshTunnelServerBuilder builder
);
}
+ /**
+ * Gets the hash string for the SSH properties provided.
+ *
+ * @param sshUser the username credential for the SSH tunnel.
+ * @param sshHostname the hostname (or IP address) for the SSH tunnel.
+ * @param sshPrivateKeyFile the path to the private key file.
+ *
+ * @return a String value representing the hash of the given properties.
+ */
+ static String getHashString(
+ final String sshUser,
+ final String sshHostname,
+ final String sshPrivateKeyFile,
+ final String remoteHostname) {
+ final String sshPropertiesString = sshUser + "-" + sshHostname + "-" + sshPrivateKeyFile + remoteHostname;
+ return Hashing.sha256()
+ .hashString(sshPropertiesString, StandardCharsets.UTF_8)
+ .toString();
+ }
+
+ /**
+ * Initializes the SSH session and creates a port forwarding tunnel.
+ *
+ * @param connectionProperties the {@link DocumentDbConnectionProperties} connection properties.
+ * @return a {@link Session} session. This session must be closed by calling the
+ * {@link Session#disconnect()} method.
+ * @throws SQLException if unable to create SSH session or create the port forwarding tunnel.
+ */
+ public static SshPortForwardingSession createSshTunnel(
+ final DocumentDbConnectionProperties connectionProperties) throws SQLException {
+ validateSshPrivateKeyFile(connectionProperties);
+
+ LOGGER.debug("Internal SSH tunnel starting.");
+ try {
+ final JSch jSch = new JSch();
+ addIdentity(connectionProperties, jSch);
+ final Session session = createSession(connectionProperties, jSch);
+ connectSession(connectionProperties, jSch, session);
+ final SshPortForwardingSession portForwardingSession = getPortForwardingSession(
+ connectionProperties, session);
+ LOGGER.debug("Internal SSH tunnel started on local port '{}'.",
+ portForwardingSession.getLocalPort());
+ LOGGER.debug("Internal SSH tunnel started.");
+ return portForwardingSession;
+ } catch (Exception e) {
+ throw logException(e);
+ }
+ }
+
+ private static SshPortForwardingSession getPortForwardingSession(
+ final DocumentDbConnectionProperties connectionProperties,
+ final Session session) throws JSchException {
+ final Pair clusterHostAndPort = getHostAndPort(
+ connectionProperties.getHostname(), DEFAULT_DOCUMENTDB_PORT);
+ final int localPort = session.setPortForwardingL(
+ LOCALHOST, 0, clusterHostAndPort.getLeft(), clusterHostAndPort.getRight());
+ return new SshPortForwardingSession(session, localPort);
+ }
+
+ private static Pair getHostAndPort(
+ final String hostname,
+ final int defaultPort) {
+ final String clusterHost;
+ final int clusterPort;
+ final int portSeparatorIndex = hostname.indexOf(':');
+ if (portSeparatorIndex >= 0) {
+ clusterHost = hostname.substring(0, portSeparatorIndex);
+ clusterPort = Integer.parseInt(
+ hostname.substring(portSeparatorIndex + 1));
+ } else {
+ clusterHost = hostname;
+ clusterPort = defaultPort;
+ }
+ return new ImmutablePair<>(clusterHost, clusterPort);
+ }
+
+ private static void connectSession(
+ final DocumentDbConnectionProperties connectionProperties,
+ final JSch jSch,
+ final Session session) throws SQLException {
+ setSecurityConfig(connectionProperties, jSch, session);
+ try {
+ session.connect();
+ } catch (JSchException e) {
+ throw logException(e);
+ }
+ }
+
+ private static void addIdentity(
+ final DocumentDbConnectionProperties connectionProperties,
+ final JSch jSch) throws JSchException {
+ final String privateKeyFileName = getPath(connectionProperties.getSshPrivateKeyFile(),
+ DocumentDbConnectionProperties.getDocumentDbSearchPaths()).toString();
+ LOGGER.debug("SSH private key file resolved to '{}'.", privateKeyFileName);
+ // If passPhrase protected, will need to provide this, too.
+ final String passPhrase = !isNullOrWhitespace(connectionProperties.getSshPrivateKeyPassphrase())
+ ? connectionProperties.getSshPrivateKeyPassphrase()
+ : null;
+ jSch.addIdentity(privateKeyFileName, passPhrase);
+ }
+
+ private static Session createSession(
+ final DocumentDbConnectionProperties connectionProperties,
+ final JSch jSch) throws SQLException {
+ final String sshUsername = connectionProperties.getSshUser();
+ final Pair sshHostAndPort = getHostAndPort(
+ connectionProperties.getSshHostname(), DEFAULT_SSH_PORT);
+ setKnownHostsFile(connectionProperties, jSch);
+ try {
+ return jSch.getSession(sshUsername, sshHostAndPort.getLeft(), sshHostAndPort.getRight());
+ } catch (JSchException e) {
+ throw logException(e);
+ }
+ }
+
+ private static void setSecurityConfig(
+ final DocumentDbConnectionProperties connectionProperties,
+ final JSch jSch,
+ final Session session) {
+ if (!connectionProperties.getSshStrictHostKeyChecking()) {
+ session.setConfig(STRICT_HOST_KEY_CHECKING, NO);
+ return;
+ }
+ setHostKeyType(connectionProperties, jSch, session);
+ }
+
+ private static void setHostKeyType(
+ final DocumentDbConnectionProperties connectionProperties,
+ final JSch jSch, final Session session) {
+ final HostKeyRepository keyRepository = jSch.getHostKeyRepository();
+ final HostKey[] hostKeys = keyRepository.getHostKey();
+ final Pair sshHostAndPort = getHostAndPort(
+ connectionProperties.getSshHostname(), DEFAULT_SSH_PORT);
+ final HostKey hostKey = Arrays.stream(hostKeys)
+ .filter(hk -> hk.getHost().equals(sshHostAndPort.getLeft()))
+ .findFirst().orElse(null);
+ // This will ensure a match between how the host key was hashed in the known_hosts file.
+ final String hostKeyType = (hostKey != null) ? hostKey.getType() : null;
+ // Append the hash algorithm
+ if (hostKeyType != null) {
+ session.setConfig(SERVER_HOST_KEY, session.getConfig(SERVER_HOST_KEY) + "," + hostKeyType);
+ }
+ // The default behaviour of `ssh-keygen` is to hash known hosts keys
+ session.setConfig(HASH_KNOWN_HOSTS, YES);
+ }
+
+ private static void setKnownHostsFile(
+ final DocumentDbConnectionProperties connectionProperties,
+ final JSch jSch) throws SQLException {
+ if (!connectionProperties.getSshStrictHostKeyChecking()) {
+ return;
+ }
+ final String knownHostsFilename;
+ knownHostsFilename = getSshKnownHostsFilename(connectionProperties);
+ try {
+ jSch.setKnownHosts(knownHostsFilename);
+ } catch (JSchException e) {
+ throw logException(e);
+ }
+ }
+
+ private static SQLException logException(final T e) {
+ LOGGER.error(e.getMessage(), e);
+ if (e instanceof SQLException) {
+ return (SQLException) e;
+ }
+ return new SQLException(e.getMessage(), e);
+ }
+
/**
* Gets the SSH tunnel service listening port. A value of zero indicates that the
* SSH tunnel service is not running.
@@ -122,14 +282,18 @@ private DocumentDbSshTunnelServer(final DocumentDbSshTunnelServerBuilder builder
* @return A port number that the SSH tunnel service is listening on.
*/
public int getServiceListeningPort() {
- return serviceListeningPort;
+ return session != null ? session.getLocalPort() : 0;
}
@Override
- public void close() throws Exception {
- synchronized (MUTEX) {
- serviceListeningPort = 0;
- shutdownServerWatcherThread();
+ public void close() {
+ synchronized (mutex) {
+ if (session != null) {
+ LOGGER.debug("Internal SSH Tunnel is stopping.");
+ session.getSession().disconnect();
+ session = null;
+ LOGGER.debug("Internal SSH Tunnel is stopped.");
+ }
}
}
@@ -137,14 +301,18 @@ public void close() throws Exception {
* Adds a client to the reference count for this server. If this is the first client, the server
* ensures that an SSH Tunnel service is started.
*
- * @throws Exception When an error occurs trying to start the SSH Tunnel service.
+ * @throws SQLException When an error occurs trying to start the SSH Tunnel service.
*/
- public void addClient() throws Exception {
- synchronized (MUTEX) {
- if (clientCount.get() == 0) {
- ensureStarted();
- }
+ public void addClient() throws SQLException {
+ // Needs to be synchronized in a single process
+ synchronized (mutex) {
+ cancelScheduledFutureClose();
clientCount.incrementAndGet();
+ if (session != null && session.getLocalPort() != 0) {
+ return;
+ }
+ validateLocalSshFilesExists();
+ session = createSshTunnel(getConnectionProperties());
}
}
@@ -152,59 +320,106 @@ public void addClient() throws Exception {
* Removes a client from the reference count for this server. If the reference count reaches zero, then
* the serve attempt to stop the SSH Tunnel service.
*
- * @throws Exception When an error occur attempting shutdown of the service process.
+ * @throws SQLException When an error occur attempting shutdown of the service process.
*/
- public void removeClient() throws Exception {
- synchronized (MUTEX) {
- if (clientCount.decrementAndGet() == 0) {
- close();
+ public void removeClient() throws SQLException {
+ synchronized (mutex) {
+ // Takes advantage of OR to only decrement if greater than zero.
+ if (clientCount.get() <= 0 || clientCount.decrementAndGet() > 0) {
+ return;
}
+ closeSession();
}
}
- private void shutdownServerWatcherThread() throws Exception {
- if (serverWatcherThread.isAlive()) {
- LOGGER.info("Stopping server watcher thread.");
- serverWatcher.close();
- do {
- serverWatcherThread.join(SERVER_WATCHER_POLL_TIME_MS * 2);
- } while (serverWatcherThread.isAlive());
- final Queue exceptions = serverWatcher.getExceptions();
- if (!exceptions.isEmpty()) {
- for (Exception e : exceptions) {
- LOGGER.error(e.getMessage(), e);
- }
- }
- LOGGER.info("Stopped server watcher thread.");
+ /**
+ * Closes the SSH tunnel session. If a close delay is given, delay the
+ * close until that time has passed.
+ *
+ * @throws SQLException In the case the task is interrupted.
+ */
+ private void closeSession() throws SQLException {
+ cancelScheduledFutureClose();
+ // Delay the close, if indicated.
+ final long delayMS = getCloseDelayMS();
+ if (delayMS <= 0) {
+ close();
} else {
- LOGGER.info("Server watcher thread already stopped.");
+ LOGGER.debug("Close timer is being scheduled.");
+ scheduledFuture = scheduler.schedule(getCloseTimerTask(), delayMS, TimeUnit.MILLISECONDS);
}
}
/**
- * Checks the state of the SSH tunnel service.
+ * Gets the {@link Runnable} task to close the SSH tunnel session.
*
- * @return Returns true if the SSH tunnel service is running.
+ * @return the task to close the SSH tunnel session.
*/
- public boolean isAlive() throws Exception {
- if (serverWatcherThread.isAlive()) {
- // While the watcher thread is running, the status should be pretty accurate.
- return serverAlive.get();
- } else {
- // Can no longer rely on watcher to have an updated status, check synchronously here.
- final Path serverLockPath = DocumentDbSshTunnelLock.getServerLockPath(propertiesHashString);
- try (DocumentDbMultiThreadFileChannel serverChannel = DocumentDbMultiThreadFileChannel.open(
- serverLockPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE)) {
- // NOTE: Server lock will be release when channel is closed.
- final FileLock serverLock = serverChannel.tryLock();
- if (serverLock != null) {
- return false;
+ private Runnable getCloseTimerTask() {
+ return () -> {
+ try {
+ close();
+ } catch (Exception e) {
+ // Ignore exception on close.
+ LOGGER.warn(e.getMessage(), e);
+ }
+ };
+ }
+
+ /**
+ * Cancels the scheduled future to close the SSH tunnel session in the case a new client gets added before
+ * the close occurs.
+ *
+ * @throws SQLException If interrupted during sleep.
+ */
+ private void cancelScheduledFutureClose() throws SQLException {
+ synchronized (mutex) {
+ if (scheduledFuture != null) {
+ LOGGER.debug("Close timer is being cancelled.");
+ while (!scheduledFuture.isDone()) {
+ scheduledFuture.cancel(false);
+ try {
+ TimeUnit.MILLISECONDS.sleep(10);
+ } catch (InterruptedException e) {
+ throw new SQLException(e.getMessage(), e);
+ }
}
}
- return true;
+ scheduledFuture = null;
}
}
+ @VisibleForTesting
+ long getCloseDelayMS() {
+ return closeDelayMS;
+ }
+
+ @VisibleForTesting
+ void setCloseDelayMS(final long closeDelayMS) {
+ this.closeDelayMS = closeDelayMS > 0 ? closeDelayMS : 0;
+ }
+
+ /**
+ * Gets the number of clients using the server.
+ *
+ * @return The number of clients using the server.
+ */
+ @VisibleForTesting
+ long getClientCount() {
+ synchronized (mutex) {
+ return clientCount.get();
+ }
+ }
+
+ /**
+ * Checks the state of the SSH tunnel service.
+ *
+ * @return Returns true if the SSH tunnel service is running.
+ */
+ public boolean isAlive() {
+ return session != null;
+ }
+
/**
* Factory method for the {@link DocumentDbSshTunnelServerBuilder} class.
*
@@ -299,7 +514,7 @@ public DocumentDbSshTunnelServerBuilder sshKnownHostsFile(final String sshKnownH
* @return a new instance of DocumentDbSshTunnelServer.
*/
public DocumentDbSshTunnelServer build() {
- final String hashString = DocumentDbSshTunnelLock.getHashString(
+ final String hashString = getHashString(
this.sshUser,
this.sshHostname,
this.sshPrivateKeyFile,
@@ -313,273 +528,6 @@ public DocumentDbSshTunnelServer build() {
}
}
- /**
- * Ensure the service is started.
- */
- void ensureStarted() throws Exception {
- maybeStart();
- }
-
- // Needs to be synchronized in a single process
- private void maybeStart() throws Exception {
- synchronized (MUTEX) {
- if (serviceListeningPort != 0) {
- return;
- }
- final AtomicReference exception = new AtomicReference<>(null);
- DocumentDbSshTunnelLock.runInGlobalLock(
- propertiesHashString,
- () -> maybeStartServerHandleException(exception));
- if (exception.get() != null) {
- throw exception.get();
- }
- }
- }
-
- private Exception maybeStartServerHandleException(final AtomicReference exception) {
- try {
- maybeStartServer();
- return null;
- } catch (Exception e) {
- exception.set(e);
- return e;
- }
- }
-
- private void maybeStartServer() throws Exception {
- final Path serverLockPath = DocumentDbSshTunnelLock.getServerLockPath(propertiesHashString);
- try (DocumentDbMultiThreadFileChannel serverChannel = DocumentDbMultiThreadFileChannel.open(
- serverLockPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE)) {
- // NOTE: Server lock will be release when channel is closed.
- FileLock serverLock = serverChannel.tryLock();
- if (serverLock != null) {
- validateLocalSshFilesExists();
-
- // This indicates that the SSH tunnel service does not have a lock.
- // So startup the SSH tunnel service and read the listening port.
- final Path startupLockPath = DocumentDbSshTunnelLock.getStartupLockPath(propertiesHashString);
- final Path portLockPath = DocumentDbSshTunnelLock.getPortLockPath(propertiesHashString);
- DocumentDbSshTunnelLock.deleteStartupAndPortLockFiles(startupLockPath, portLockPath);
-
- // Release the server lock file, which is safe since we're in the global lock.
- if (serverLock.isValid()) {
- serverLock.close();
- }
- // Start the service process
- final Process process = startSshTunnelServiceProcess();
- // Read the listening port
- waitForStartupAndReadPort(startupLockPath, process);
- // Wait for the service to lock the server lock fle.
- final Instant timeoutTime = Instant.now().plus(Duration.ofSeconds(SERVICE_WAIT_TIMEOUT_SECONDS));
- do {
- serverLock = serverChannel.tryLock();
- if (serverLock != null && serverLock.isValid()) {
- serverLock.close();
- // Ensure we don't wait forever.
- throwIfTimeout(timeoutTime, "Timeout waiting for service to acquire server lock.");
- TimeUnit.MILLISECONDS.sleep(SERVER_WATCHER_POLL_TIME_MS);
- }
- } while (serverLock != null);
-
- // Now it's safe to start the watcher thread.
- startServerWatcherThread();
- } else {
- // This indicates that the SSH tunnel service does have a lock.
- // So just read the listening port.
- LOGGER.info("Server already running.");
- readSshPortFromFile();
- // Now it's safe to start the watcher thread.
- startServerWatcherThread();
- }
- }
- }
-
- private static void throwIfTimeout(final Instant timeoutTime, final String message) throws SQLException {
- if (Instant.now().isAfter(timeoutTime)) {
- throw SqlError.createSQLException(LOGGER,
- SqlState.CONNECTION_EXCEPTION,
- SqlError.SSH_TUNNEL_ERROR,
- message);
- }
- }
-
- private void startServerWatcherThread() {
- serverWatcher = new ServerWatcher(propertiesHashString, serverAlive);
- serverWatcherThread = new Thread(serverWatcher);
- serverWatcherThread.setDaemon(true);
- serverWatcherThread.start();
- }
-
- private void waitForStartupAndReadPort(final Path startupLockPath, final Process process) throws Exception {
- final int pollTimeMS = 100;
- while (!Files.exists(startupLockPath)) {
- TimeUnit.MILLISECONDS.sleep(pollTimeMS);
- }
- try (DocumentDbMultiThreadFileChannel startupChannel = DocumentDbMultiThreadFileChannel.open(
- startupLockPath, StandardOpenOption.WRITE, StandardOpenOption.READ)) {
- FileLock startupLock;
- LOGGER.info("Waiting for server to unlock Startup lock file.");
- final Instant timeoutTime = Instant.now().plus(Duration.ofSeconds(SERVICE_WAIT_TIMEOUT_SECONDS));
- do {
-
- startupLock = startupChannel.tryLock();
- if (startupLock == null) {
- throwIfProcessHasExited(process);
- // Ensure we don't wait forever.
- throwIfTimeout(timeoutTime, "Timeout waiting for service to release Startup lock.");
- TimeUnit.MILLISECONDS.sleep(pollTimeMS);
- }
- } while (startupLock == null);
- LOGGER.info("Server has unlocked Startup lock file.");
- LOGGER.info("Reading Startup lock file.");
-
- try (InputStream inputStream = Channels.newInputStream(startupChannel.getFileChannel());
- InputStreamReader streamReader = new InputStreamReader(inputStream, StandardCharsets.UTF_8);
- BufferedReader reader = new BufferedReader(streamReader)) {
- final StringBuilder exceptionMessage = new StringBuilder();
- boolean isFirstLine = true;
- String line;
- while ((line = reader.readLine()) != null && !DocumentDbConnectionProperties.isNullOrWhitespace(line)) {
- if (!isFirstLine) {
- exceptionMessage.append(System.lineSeparator());
- } else {
- isFirstLine = false;
- }
- exceptionMessage.append(line);
- }
- if (exceptionMessage.length() > 0) {
- exceptionMessage.insert(0, "Server exception detected: '").append("'");
- throw SqlError.createSQLException(
- LOGGER,
- SqlState.CONNECTION_EXCEPTION,
- SqlError.SSH_TUNNEL_ERROR,
- exceptionMessage.toString());
- }
- LOGGER.info("Finished reading Startup lock file.");
- }
- LOGGER.info("Reading local port number from file.");
- readSshPortFromFile();
- }
- }
-
- private static void throwIfProcessHasExited(final Process process) throws InterruptedException, SQLException {
- synchronized (process) {
- if (process.waitFor(1, TimeUnit.MILLISECONDS)) {
- throw SqlError.createSQLException(LOGGER,
- SqlState.CONNECTION_EXCEPTION,
- SqlError.SSH_TUNNEL_ERROR,
- "Service has unexpected exited.");
- }
- }
- }
-
- private Process startSshTunnelServiceProcess()
- throws IOException, SQLException, URISyntaxException {
- final List command = getSshTunnelCommand();
- final ProcessBuilder builder = new ProcessBuilder(command);
- return builder.inheritIO().start();
- }
-
- private List getSshTunnelCommand() throws SQLException, URISyntaxException {
- final List command = new LinkedList<>();
- final String docDbSshTunnelPathString = System.getenv(DOCUMENTDB_SSH_TUNNEL_PATH);
- if (docDbSshTunnelPathString != null) {
- // NOTE: This is the entry point for the ODBC driver to provide a (full) path to the executable.
- // It is assumed that we will still provide all the arguments to the executable.
- // E.g., on Windows:
- // DOCUMENTDB_SSH_TUNNEL_PATH=C:\Program Files\documentdb-ssh-tunnel-service\documentdb-ssh-tunnel-service.exe
- final Path docDbSshTunnelPath = Paths.get(docDbSshTunnelPathString);
- if (!Files.isExecutable(docDbSshTunnelPath) || Files.isDirectory(docDbSshTunnelPath)) {
- throw SqlError.createSQLException(
- LOGGER,
- SqlState.CONNECTION_EXCEPTION,
- SqlError.SSH_TUNNEL_PATH_NOT_FOUND,
- docDbSshTunnelPathString);
- }
-
- command.add(docDbSshTunnelPath.toAbsolutePath().toString());
- command.add(SSH_TUNNEL_SERVICE_OPTION_NAME);
- command.add(getSshPropertiesString());
- } else {
- final String className = DocumentDbMain.class.getName();
- final String sshConnectionProperties = getSshPropertiesString();
- command.addAll(getJavaCommand(className, SSH_TUNNEL_SERVICE_OPTION_NAME, sshConnectionProperties));
- }
- return command;
- }
-
- /**
- * Gets the command line parameters for invoking the Java executable in the JAVA_HOME.
- *
- * @param className the name of the class that contains the 'main()' method.
- * @param arguments the arguments to the main method.
- * @return a list of arguments to pass to {@link ProcessBuilder}.
- * @throws SQLException when the path to the java executable cannot be resolved.
- */
- public static List getJavaCommand(final String className, final String... arguments)
- throws SQLException, URISyntaxException {
- final String javaBinFilePath = getJavaBinFilePath();
- final String combinedClassPath = getCombinedClassPath();
-
- final List command = new LinkedList<>();
- command.add(javaBinFilePath);
- command.add(CLASS_PATH_OPTION_NAME);
- command.add(combinedClassPath);
- command.add(className);
- command.addAll(Arrays.asList(arguments));
- return command;
- }
-
- private static String getJavaBinFilePath() throws SQLException {
- // Check that the java command executable is available relative to the
- // JAVA_HOME environment variable.
- final String javaHome = getJavaHome();
- final String javaBinFilePath = Paths.get(javaHome, BIN_FOLDER_NAME, JAVA_EXECUTABLE_NAME).toString();
- final boolean isOsWindows = org.apache.commons.lang3.SystemUtils.IS_OS_WINDOWS;
- final Path javaBinPath = Paths.get(javaBinFilePath + (isOsWindows ? ".exe" : ""));
- if (!Files.exists(javaBinPath) || !Files.isExecutable(javaBinPath)) {
- throw SqlError.createSQLException(
- LOGGER,
- SqlState.CONNECTION_EXCEPTION,
- SqlError.MISSING_JAVA_BIN,
- javaBinPath.toString());
- }
- return javaBinFilePath;
- }
-
- private static String getJavaHome() throws SQLException {
- final String javaHome = System.getProperty(JAVA_HOME);
- if (isNullOrWhitespace(javaHome) || !Files.exists(Paths.get(javaHome))) {
- throw SqlError.createSQLException(
- LOGGER,
- SqlState.CONNECTION_EXCEPTION,
- SqlError.MISSING_JAVA_HOME);
- }
- return javaHome;
- }
-
- private static String getCombinedClassPath() throws URISyntaxException {
- URI currentClassPathUri = DocumentDbSshTunnelServer.class
- .getProtectionDomain().getCodeSource().getLocation().toURI();
- final String schemeSpecificPart = currentClassPathUri.getSchemeSpecificPart();
- if (currentClassPathUri.getScheme().equalsIgnoreCase(FILE_SCHEME)
- && !isNullOrWhitespace(schemeSpecificPart)) {
- // Ensure only 1 slash at beginning.
- final String startsWithSlashExpression = "^/+";
- currentClassPathUri = new URI(currentClassPathUri.getScheme()
- + ":/"
- + schemeSpecificPart.replaceAll(startsWithSlashExpression, ""));
-
- }
- final String currentClassCodeSourcePath = new File(currentClassPathUri).getAbsolutePath();
- return currentClassCodeSourcePath + ";" + System.getProperty(JAVA_CLASS_PATH);
- }
-
- private @NonNull String getSshPropertiesString() {
- final DocumentDbConnectionProperties connectionProperties = getConnectionProperties();
- return DocumentDbConnectionProperties.DOCUMENT_DB_SCHEME + connectionProperties.buildSshConnectionString();
- }
-
@NonNull
private DocumentDbConnectionProperties getConnectionProperties() {
final DocumentDbConnectionProperties connectionProperties = new DocumentDbConnectionProperties();
@@ -622,7 +570,7 @@ static String getSshKnownHostsFilename(final DocumentDbConnectionProperties conn
validateSshKnownHostsFile(connectionProperties, knownHostsPath);
knowHostsFilename = knownHostsPath.toString();
} else {
- knowHostsFilename = getPath(DocumentDbSshTunnelService.SSH_KNOWN_HOSTS_FILE).toString();
+ knowHostsFilename = getPath(SSH_KNOWN_HOSTS_FILE).toString();
}
return knowHostsFilename;
}
@@ -639,100 +587,19 @@ private static void validateSshKnownHostsFile(
}
}
- private void readSshPortFromFile() throws IOException, SQLException {
- final Path portLockPath = DocumentDbSshTunnelLock.getPortLockPath(propertiesHashString);
- final List lines = Files.readAllLines(portLockPath, StandardCharsets.UTF_8);
- int port = 0;
- for (String line : lines) {
- if (!line.trim().isEmpty()) {
- port = Integer.parseInt(line.trim());
- if (port > 0) {
- break;
- }
- }
- }
- if (port <= 0) {
- serviceListeningPort = 0;
- throw SqlError.createSQLException(
- LOGGER,
- SqlState.CONNECTION_EXCEPTION,
- SqlError.SSH_TUNNEL_ERROR,
- "Unable to read valid listening port for SSH Tunnel service.");
- }
- serviceListeningPort = port;
- LOGGER.info("SHH tunnel service listening on port: " + serviceListeningPort);
- }
-
- private enum ServerWatcherState {
- INITIALIZED,
- RUNNING,
- INTERRUPTED,
- COMPLETED,
- ERROR,
- }
-
- private static class ServerWatcher implements Runnable, AutoCloseable {
-
- private volatile ServerWatcherState state = ServerWatcherState.INITIALIZED;
- private final Queue exceptions = new ConcurrentLinkedDeque<>();
- private final String propertiesHashString;
- private final AtomicBoolean serverAlive;
-
- ServerWatcher(final String propertiesHashString, final AtomicBoolean serverAlive) {
- this.propertiesHashString = propertiesHashString;
- this.serverAlive = serverAlive;
- }
-
+ /**
+ * Container for the SSH port forwarding tunnel session.
+ */
+ @Getter
+ @AllArgsConstructor
+ static class SshPortForwardingSession {
/**
- * Gets the queue of exceptions.
- *
- * @return a queue of exceptions.
+ * Gets the SSH session.
*/
- public Queue getExceptions() {
- return exceptions;
- }
-
- @Override
- public void run() {
- try {
- state = ServerWatcherState.RUNNING;
- do {
- DocumentDbSshTunnelLock.runInGlobalLock(propertiesHashString, this::checkForServerLock);
- if (state == ServerWatcherState.RUNNING) {
- TimeUnit.MILLISECONDS.sleep(SERVER_WATCHER_POLL_TIME_MS);
- }
- } while (state == ServerWatcherState.RUNNING);
- } catch (Exception e) {
- exceptions.add(e);
- }
- }
-
- @Override
- public void close() throws Exception {
- state = ServerWatcherState.INTERRUPTED;
- }
-
- private Exception checkForServerLock() {
- Exception exception = null;
- final Path serverLockPath = DocumentDbSshTunnelLock.getServerLockPath(propertiesHashString);
- try (DocumentDbMultiThreadFileChannel serverChannel = DocumentDbMultiThreadFileChannel.open(
- serverLockPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE)) {
- // NOTE: Server lock will be release when channel is closed.
- final FileLock serverLock = serverChannel.tryLock();
- if (serverLock != null) {
- // Server abandoned lock. Set to false if the previous value was true.
- serverAlive.compareAndSet(true, false);
- state = ServerWatcherState.COMPLETED;
- } else {
- // Server is still alive. Set to true if the previous value was false.
- serverAlive.compareAndSet(false, true);
- }
- } catch (Exception e) {
- exception = e;
- exceptions.add(e);
- state = ServerWatcherState.ERROR;
- }
- return exception;
- }
+ private final Session session;
+ /**
+ * Gets the local port for the port forwarding tunnel.
+ */
+ private final int localPort;
}
}
diff --git a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelService.java b/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelService.java
deleted file mode 100644
index 97b8f06b..00000000
--- a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelService.java
+++ /dev/null
@@ -1,560 +0,0 @@
-/*
- * Copyright <2022> Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License").
- * You may not use this file except in compliance with the License.
- * A copy of the License is located at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * or in the "license" file accompanying this file. This file is distributed
- * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied. See the License for the specific language governing
- * permissions and limitations under the License.
- *
- */
-
-package software.amazon.documentdb.jdbc.sshtunnel;
-
-import com.jcraft.jsch.HostKey;
-import com.jcraft.jsch.HostKeyRepository;
-import com.jcraft.jsch.JSch;
-import com.jcraft.jsch.JSchException;
-import com.jcraft.jsch.Session;
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import lombok.SneakyThrows;
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import org.apache.commons.lang3.tuple.Pair;
-import org.checkerframework.checker.nullness.qual.NonNull;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import software.amazon.documentdb.jdbc.DocumentDbConnectionProperties;
-
-import java.io.BufferedWriter;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.nio.channels.Channels;
-import java.nio.channels.FileLock;
-import java.nio.charset.StandardCharsets;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.StandardOpenOption;
-import java.sql.SQLException;
-import java.util.AbstractMap;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import java.util.Queue;
-import java.util.concurrent.ConcurrentLinkedDeque;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.ValidationType.SSH_TUNNEL;
-import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.getPath;
-import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.isNullOrWhitespace;
-
-/**
- * The DocumentDbSshTunnelService class provide a runnable service to host an SSH Tunnel.
- * It monitors the running clients and exits when there are no more active clients.
- */
-public class DocumentDbSshTunnelService implements AutoCloseable, Runnable {
- public static final String SSH_KNOWN_HOSTS_FILE = "~/.ssh/known_hosts";
- public static final String STRICT_HOST_KEY_CHECKING = "StrictHostKeyChecking";
- public static final String HASH_KNOWN_HOSTS = "HashKnownHosts";
- public static final String SERVER_HOST_KEY = "server_host_key";
- public static final String YES = "yes";
- public static final String NO = "no";
- public static final String LOCALHOST = "localhost";
- public static final int DEFAULT_DOCUMENTDB_PORT = 27017;
- public static final int DEFAULT_SSH_PORT = 22;
- private static final Logger LOGGER = LoggerFactory.getLogger(DocumentDbSshTunnelService.class);
- public static final int CLIENT_WATCH_POLL_TIME = 500;
- private final DocumentDbConnectionProperties connectionProperties;
- private final String sshPropertiesHashString;
- private volatile boolean completed = false;
- private volatile boolean interrupted = false;
- private final ConcurrentLinkedDeque exceptions = new ConcurrentLinkedDeque<>();
-
- /**
- * Constructs a new instance of DocumentDbSshTunnelService.
- *
- * @param connectionString the SSH tunnel connection string properties.
- * @throws SQLException thrown if unable to parse connection string.
- */
- public DocumentDbSshTunnelService(final String connectionString) throws SQLException {
- connectionProperties = DocumentDbConnectionProperties.getPropertiesFromConnectionString(
- connectionString, SSH_TUNNEL);
- sshPropertiesHashString = DocumentDbSshTunnelLock.getHashString(
- connectionProperties.getSshUser(),
- connectionProperties.getSshHostname(),
- connectionProperties.getSshPrivateKeyFile(),
- connectionProperties.getHostname());
- }
-
- @Override
- @SneakyThrows
- public void close() {
- interrupted = true;
- }
-
- /**
- * Runs the SSH tunnel and polls for client lock files.
- * When all the client lock files are gone or unlocked, then this method stops and
- * cleans up any resources.
- */
- @Override
- public void run() {
- SshPortForwardingSession session = null;
- DocumentDbMultiThreadFileChannel serverChannel = null;
- FileLock serverLock = null;
-
- while (!interrupted && !completed) {
- try {
- LOGGER.debug("SSH Tunnel service starting.");
- session = performSshTunnelSessionStartup();
- final Map.Entry lock = acquireServerLock();
- serverChannel = lock.getKey();
- serverLock = lock.getValue();
- LOGGER.debug("SSH Tunnel service started.");
-
- // launch thread and wait for clients to terminate.
- waitForClients(serverLock);
- } catch (InterruptedException e) {
- logException(e);
- interrupted = true;
- } catch (Exception e) {
- exceptions.add(logException(e));
- } finally {
- try {
- LOGGER.debug("SSH Tunnel service stopping.");
- cleanupResourcesInGlobalLock(session, serverChannel, serverLock);
- } catch (Exception e) {
- exceptions.add(logException(e));
- }
- completed = true;
- }
- }
- LOGGER.debug("SSH Tunnel service stopped.");
- }
-
- private void cleanupResourcesInGlobalLock(
- final SshPortForwardingSession session,
- final DocumentDbMultiThreadFileChannel serverChannel,
- final FileLock serverLock) throws Exception {
- DocumentDbSshTunnelLock.runInGlobalLock(
- sshPropertiesHashString,
- () -> closeResources(session, serverChannel, serverLock, exceptions));
- }
-
- private Queue closeResources(
- final SshPortForwardingSession session,
- final DocumentDbMultiThreadFileChannel serverChannel,
- final FileLock serverLock,
- final Queue exceptions) {
- try {
- if (serverLock != null && serverLock.isValid()) {
- serverLock.close();
- }
- if (serverChannel != null && serverChannel.isOpen()) {
- serverChannel.close();
- }
- if (session != null) {
- session.getSession().disconnect();
- }
- final Path portLockPath = DocumentDbSshTunnelLock.getPortLockPath(sshPropertiesHashString);
- final Path startupLockPath = DocumentDbSshTunnelLock.getStartupLockPath(sshPropertiesHashString);
- Files.deleteIfExists(portLockPath);
- Files.deleteIfExists(startupLockPath);
- } catch (Exception e) {
- exceptions.add(logException(e));
- }
- return exceptions;
- }
-
- /**
- * Gets the SSH tunnel properties hash string.
- *
- * @return a {@link String} representing the SSH tunnel properties hash string.
- */
- public String getSshPropertiesHashString() {
- return sshPropertiesHashString;
- }
-
- private void
- waitForClients(final FileLock serverLock) throws InterruptedException {
- ClientWatcher clientWatcher = null;
- try {
- clientWatcher = new ClientWatcher(serverLock, sshPropertiesHashString);
- final Thread clientWatcherThread = new Thread(clientWatcher);
- clientWatcherThread.setDaemon(true);
- clientWatcherThread.start();
- do {
- clientWatcherThread.join(1000);
- } while (clientWatcherThread.isAlive() && !interrupted);
- } finally {
- if (clientWatcher != null) {
- exceptions.addAll(clientWatcher.getExceptions());
- }
- }
- }
-
- /**
- * Closes (and unlocks) the server lock if not already unlocked.
- */
- private static Exception closeServerLock(final FileLock serverLock) {
- Exception result = null;
- if (serverLock != null && serverLock.isValid()) {
- try {
- serverLock.close();
- } catch (IOException e) {
- result = logException(e);
- }
- }
- return result;
- }
-
-
- /**
- * Initializes the SSH session and creates a port forwarding tunnel.
- *
- * @param connectionProperties the {@link DocumentDbConnectionProperties} connection properties.
- * @return a {@link Session} session. This session must be closed by calling the
- * {@link Session#disconnect()} method.
- * @throws SQLException if unable to create SSH session or create the port forwarding tunnel.
- */
- public static SshPortForwardingSession createSshTunnel(
- final DocumentDbConnectionProperties connectionProperties) throws SQLException {
- DocumentDbSshTunnelServer.validateSshPrivateKeyFile(connectionProperties);
-
- LOGGER.debug("Internal SSH tunnel starting.");
- try {
- final JSch jSch = new JSch();
- addIdentity(connectionProperties, jSch);
- final Session session = createSession(connectionProperties, jSch);
- connectSession(connectionProperties, jSch, session);
- final SshPortForwardingSession portForwardingSession = getPortForwardingSession(
- connectionProperties, session);
- LOGGER.debug("Internal SSH tunnel started on local port '{}'.",
- portForwardingSession.getLocalPort());
- return portForwardingSession;
- } catch (Exception e) {
- throw logException(e);
- }
- }
-
- private static SshPortForwardingSession getPortForwardingSession(
- final DocumentDbConnectionProperties connectionProperties,
- final Session session) throws JSchException {
- final Pair clusterHostAndPort = getHostAndPort(
- connectionProperties.getHostname(), DEFAULT_DOCUMENTDB_PORT);
- final int localPort = session.setPortForwardingL(
- LOCALHOST, 0, clusterHostAndPort.getLeft(), clusterHostAndPort.getRight());
- return new SshPortForwardingSession(session, localPort);
- }
-
- private static Pair getHostAndPort(
- final String hostname,
- final int defaultPort) {
- final String clusterHost;
- final int clusterPort;
- final int portSeparatorIndex = hostname.indexOf(':');
- if (portSeparatorIndex >= 0) {
- clusterHost = hostname.substring(0, portSeparatorIndex);
- clusterPort = Integer.parseInt(
- hostname.substring(portSeparatorIndex + 1));
- } else {
- clusterHost = hostname;
- clusterPort = defaultPort;
- }
- return new ImmutablePair<>(clusterHost, clusterPort);
- }
-
- private static void connectSession(
- final DocumentDbConnectionProperties connectionProperties,
- final JSch jSch,
- final Session session) throws SQLException {
- setSecurityConfig(connectionProperties, jSch, session);
- try {
- session.connect();
- } catch (JSchException e) {
- throw logException(e);
- }
- }
-
- private static void addIdentity(
- final DocumentDbConnectionProperties connectionProperties,
- final JSch jSch) throws JSchException {
- final String privateKeyFileName = getPath(connectionProperties.getSshPrivateKeyFile(),
- DocumentDbConnectionProperties.getDocumentDbSearchPaths()).toString();
- LOGGER.debug("SSH private key file resolved to '{}'.", privateKeyFileName);
- // If passPhrase protected, will need to provide this, too.
- final String passPhrase = !isNullOrWhitespace(connectionProperties.getSshPrivateKeyPassphrase())
- ? connectionProperties.getSshPrivateKeyPassphrase()
- : null;
- jSch.addIdentity(privateKeyFileName, passPhrase);
- }
-
- private static Session createSession(
- final DocumentDbConnectionProperties connectionProperties,
- final JSch jSch) throws SQLException {
- final String sshUsername = connectionProperties.getSshUser();
- final Pair sshHostAndPort = getHostAndPort(
- connectionProperties.getSshHostname(), DEFAULT_SSH_PORT);
- setKnownHostsFile(connectionProperties, jSch);
- try {
- return jSch.getSession(sshUsername, sshHostAndPort.getLeft(), sshHostAndPort.getRight());
- } catch (JSchException e) {
- throw logException(e);
- }
- }
-
- private static void setSecurityConfig(
- final DocumentDbConnectionProperties connectionProperties,
- final JSch jSch,
- final Session session) {
- if (!connectionProperties.getSshStrictHostKeyChecking()) {
- session.setConfig(STRICT_HOST_KEY_CHECKING, NO);
- return;
- }
- setHostKeyType(connectionProperties, jSch, session);
- }
-
- private static void setHostKeyType(
- final DocumentDbConnectionProperties connectionProperties,
- final JSch jSch, final Session session) {
- final HostKeyRepository keyRepository = jSch.getHostKeyRepository();
- final HostKey[] hostKeys = keyRepository.getHostKey();
- final Pair sshHostAndPort = getHostAndPort(
- connectionProperties.getSshHostname(), DEFAULT_SSH_PORT);
- final HostKey hostKey = Arrays.stream(hostKeys)
- .filter(hk -> hk.getHost().equals(sshHostAndPort.getLeft()))
- .findFirst().orElse(null);
- // This will ensure a match between how the host key was hashed in the known_hosts file.
- final String hostKeyType = (hostKey != null) ? hostKey.getType() : null;
- // Append the hash algorithm
- if (hostKeyType != null) {
- session.setConfig(SERVER_HOST_KEY, session.getConfig(SERVER_HOST_KEY) + "," + hostKeyType);
- }
- // The default behaviour of `ssh-keygen` is to hash known hosts keys
- session.setConfig(HASH_KNOWN_HOSTS, YES);
- }
-
- private static void setKnownHostsFile(
- final DocumentDbConnectionProperties connectionProperties,
- final JSch jSch) throws SQLException {
- if (!connectionProperties.getSshStrictHostKeyChecking()) {
- return;
- }
- final String knowHostsFilename;
- knowHostsFilename = DocumentDbSshTunnelServer.getSshKnownHostsFilename(connectionProperties);
- try {
- jSch.setKnownHosts(knowHostsFilename);
- } catch (JSchException e) {
- throw logException(e);
- }
- }
- private Map.Entry acquireServerLock() throws IOException, InterruptedException {
- final Path serverLockPath = DocumentDbSshTunnelLock.getServerLockPath(sshPropertiesHashString);
- final Path parentPath = serverLockPath.getParent();
- assert parentPath != null;
- Files.createDirectories(parentPath);
- final DocumentDbMultiThreadFileChannel serverChannel = DocumentDbMultiThreadFileChannel.open(
- serverLockPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE);
- FileLock serverLock;
- final int pollTimeMS = 100;
- while ((serverLock = serverChannel.tryLock()) == null) {
- TimeUnit.MILLISECONDS.sleep(pollTimeMS);
- }
- return new AbstractMap.SimpleImmutableEntry<>(serverChannel, serverLock);
- }
-
- private SshPortForwardingSession performSshTunnelSessionStartup()
- throws Exception {
- if (!connectionProperties.enableSshTunnel()) {
- throw new UnsupportedOperationException(
- "Unable to create SSH tunnel session. Invalid properties provided.");
- }
- final SshPortForwardingSession session;
- final Path startupLockPath = DocumentDbSshTunnelLock.getStartupLockPath(sshPropertiesHashString);
- final Path parentPath = startupLockPath.getParent();
- assert parentPath != null;
- Files.createDirectories(parentPath);
- try (DocumentDbMultiThreadFileChannel startupChannel = DocumentDbMultiThreadFileChannel.open(
- startupLockPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE);
- BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
- Channels.newOutputStream(startupChannel.getFileChannel()),
- StandardCharsets.UTF_8));
- FileLock ignored = startupChannel.lock()) {
- try {
- session = createSshTunnel(connectionProperties);
- } catch (Exception e) {
- logException(e);
- writer.write(e.toString());
- throw e;
- }
- writeSssTunnelPort(session);
- }
- return session;
- }
-
- private void writeSssTunnelPort(final SshPortForwardingSession session) throws IOException {
- final Path portLockPath = DocumentDbSshTunnelLock.getPortLockPath(sshPropertiesHashString);
- try (FileOutputStream outputStream = new FileOutputStream(portLockPath.toFile());
- BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8));
- FileLock ignored = outputStream.getChannel().lock()) {
- writer.write(String.format("%d%n", session.getLocalPort()));
- }
- }
-
- /**
- * Gets a copy of the list of exceptions raised while the service is running.
- *
- * @return a list of exceptions raised while the service is running.
- */
- public List getExceptions() {
- return Collections.unmodifiableList(new ArrayList<>(exceptions));
- }
-
- /**
- * The ClientWatcher class implements the {@link Runnable} interface. When run,
- * it monitors the clients lock folder for client lock files. When no more locked
- * files exist, the run method ends. Before exiting, it will release the passed server lock file
- * inside the global lock to ensure that there will not be a race condition with new clients trying
- * to start up.
- */
- private static class ClientWatcher implements Runnable {
- private enum ThreadState {
- UNKNOWN,
- RUNNING,
- INTERRUPTED,
- EXITING,
- }
-
- private final ConcurrentLinkedDeque exceptions = new ConcurrentLinkedDeque<>();
- private final FileLock serverLock;
- private final String sshPropertiesHashString;
-
- public ClientWatcher(final FileLock serverLock, final String sshPropertiesHashString) {
- this.serverLock = serverLock;
- this.sshPropertiesHashString = sshPropertiesHashString;
- }
-
- @Override
- public void run() {
- ThreadState state = ThreadState.RUNNING;
- try {
- final AtomicInteger clientCount = new AtomicInteger();
- do {
- clientCount.set(0);
- DocumentDbSshTunnelLock.runInGlobalLock(
- sshPropertiesHashString,
- () -> checkAndHandleClientLocks(clientCount, sshPropertiesHashString, serverLock));
- if (clientCount.get() > 0) {
- TimeUnit.MILLISECONDS.sleep(CLIENT_WATCH_POLL_TIME);
- } else {
- state = ThreadState.EXITING;
- }
- } while (state == ThreadState.RUNNING);
- } catch (Exception e) {
- exceptions.add(logException(e));
- } finally {
- try {
- final Exception localException = DocumentDbSshTunnelLock.runInGlobalLock(
- sshPropertiesHashString, () -> closeServerLock(serverLock));
- if (localException != null) {
- exceptions.add(localException);
- }
- } catch (Exception e) {
- exceptions.add(logException(e));
- }
- }
- }
-
- /**
- * Checks all the client lock files. If any client lock files can be locked, then
- * client has abandoned the file, and it can be deleted. File locks that cannot be attained
- * must be considered alive. If there are no locked files, then we can safely unlock and close the server
- * lock.
- *
- * @param clientCount the number of alive clients with locked files.
- */
- @SneakyThrows
- private static Exception checkAndHandleClientLocks(
- final AtomicInteger clientCount, final String sshPropertiesHashString, final FileLock serverLock) {
- Exception result = null;
- final Path clientsFolderPath = DocumentDbSshTunnelLock.getClientsFolderPath(sshPropertiesHashString);
- Files.createDirectories(clientsFolderPath);
- try (Stream files = Files.list(clientsFolderPath)) {
- for (Path filePath : files.collect(Collectors.toList())) {
- final Exception exception = checkClientLock(clientCount, filePath);
- if (exception != null) {
- return exception;
- }
- }
- }
- if (clientCount.get() == 0) {
- result = closeServerLock(serverLock);
- }
- return result;
- }
-
- /**
- * Checks the client lock for one file path.
- * ASSUMPTION: this method is called from withing a global lock.
- *
- * @param clientCount the number of active clients with locked files.
- * @param filePath the path to the client lock file.
- */
- private static Exception checkClientLock(final AtomicInteger clientCount, final Path filePath) {
- Exception result = null;
- try (DocumentDbMultiThreadFileChannel fileChannel = DocumentDbMultiThreadFileChannel.open(
- filePath, StandardOpenOption.WRITE)) {
- final FileLock fileLock = fileChannel.tryLock();
- if (fileLock == null) {
- clientCount.getAndIncrement();
- } else {
- fileLock.close();
- Files.deleteIfExists(filePath);
- }
- } catch (Exception e) {
- result = logException(e);
- }
- return result;
- }
-
- @NonNull
- public ConcurrentLinkedDeque getExceptions() {
- return exceptions;
- }
- }
-
- private static SQLException logException(final T e) {
- LOGGER.error(e.getMessage(), e);
- if (e instanceof SQLException) {
- return (SQLException) e;
- }
- return new SQLException(e.getMessage(), e);
- }
-
- /**
- * Container for the SSH port forwarding tunnel session.
- */
- @Getter
- @AllArgsConstructor
- static class SshPortForwardingSession {
- /**
- * Gets the SSH session.
- */
- private final Session session;
- /**
- * Gets the local port for the port forwarding tunnel.
- */
- private final int localPort;
- }
-}
diff --git a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelTestClientRunner.java b/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelTestClientRunner.java
deleted file mode 100644
index c2125b7a..00000000
--- a/src/main/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelTestClientRunner.java
+++ /dev/null
@@ -1,151 +0,0 @@
-/*
- * Copyright <2022> Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License").
- * You may not use this file except in compliance with the License.
- * A copy of the License is located at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * or in the "license" file accompanying this file. This file is distributed
- * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied. See the License for the specific language governing
- * permissions and limitations under the License.
- *
- */
-
-package software.amazon.documentdb.jdbc.sshtunnel;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.lang.management.ManagementFactory;
-import java.security.SecureRandom;
-import java.sql.Connection;
-import java.sql.DriverManager;
-import java.util.AbstractMap;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
-
-/**
- * This class is used to test the internal SSH tunnel by
- * creating a number of connections on a number of threads.
- *
- * It is for testing purposes only.
- */
-class DocumentDbSshTunnelTestClientRunner {
- private static final Logger LOGGER = LoggerFactory.getLogger(DocumentDbSshTunnelTestClientRunner.class);
- private static final String PROCESS_NAME = ManagementFactory.getRuntimeMXBean().getName();
- private static String connectionString;
- private static int clientRunTime;
-
- /**
- * Main entry point to client runner test application.
- *
- * @param args the command line arguments
- */
- public static void main(final String[] args) {
-
- boolean hasExceptions = false;
- if (args.length < 1) {
- LOGGER.error("Unexpected number of arguments. Required: connectionString [maxNumberOfClients]");
- System.exit(-1);
- }
-
- connectionString = args[0];
- final int maxNumberOfClients = args.length > 1 ? Integer.parseInt(args[1]) : 1;
- clientRunTime = args.length > 2 ? Integer.parseInt(args[2]) : 30;
-
- final List> runners = new ArrayList<>();
-
- try {
- for (int index = 0; index < maxNumberOfClients; index++) {
- startConnectionRunner(runners, index);
- }
- } catch (Exception e) {
- hasExceptions = true;
- writeException(e);
- } finally {
- for (Map.Entry entry : runners) {
- try {
- waitForConnectionRunner(entry);
- } catch (Exception e) {
- hasExceptions = true;
- writeException(e);
- }
- }
- runners.clear();
- }
- System.exit(hasExceptions ? 1 : 0);
- }
-
- private static void waitForConnectionRunner(final Map.Entry entry) throws InterruptedException {
- LOGGER.debug(PROCESS_NAME + ": Stopping entry");
- entry.getValue().join();
- if (entry.getKey().exception != null) {
- LOGGER.error("Connection failed", entry.getKey().exception);
- }
- LOGGER.debug(PROCESS_NAME + ": Stopped entry");
- }
-
- private static void startConnectionRunner(
- final List> runners,
- final int index) {
- LOGGER.debug(PROCESS_NAME + ": Starting client " + index);
- final Thread runnerThread = getRunnerThread(runners);
- runnerThread.start();
- LOGGER.debug(PROCESS_NAME + ": Started client " + index);
- }
-
- private static Thread getRunnerThread(
- final List> runners) {
- final ClientConnectionRunner runner = new ClientConnectionRunner(connectionString, clientRunTime);
- final Thread runnerThread = new Thread(runner);
- runners.add(new AbstractMap.SimpleImmutableEntry<>(runner, runnerThread));
- return runnerThread;
- }
-
- private static void writeException(final Exception e) {
- LOGGER.error(PROCESS_NAME + ": Exception: " + e);
- LOGGER.error(Arrays.stream(e.getStackTrace())
- .map(StackTraceElement::toString)
- .collect(Collectors.joining(System.lineSeparator())));
- }
-
- private static class ClientConnectionRunner implements Runnable, AutoCloseable {
- public static final SecureRandom RANDOM = new SecureRandom();
- private volatile Exception exception = null;
- private final String connectionString;
- private final int waitTimeoutSECS;
-
- public ClientConnectionRunner(final String connectionString, final int waitTimeoutSECS) {
- this.connectionString = connectionString;
- this.waitTimeoutSECS = waitTimeoutSECS;
- }
-
- @Override
- public void run() {
- try (Connection connection = DriverManager.getConnection(connectionString)) {
- final boolean connected = connection.isValid(0);
- LOGGER.debug("Connection is valid: " + connected);
- assert connected;
- final int randomExtension = RANDOM.nextInt(Math.max(1, (int) (0.25 * waitTimeoutSECS)));
- TimeUnit.SECONDS.sleep(waitTimeoutSECS + randomExtension);
- } catch (Exception e) {
- exception = e;
- }
- }
-
- @Override
- public void close() throws Exception {
- }
-
- public Exception getException() {
- return exception;
- }
- }
-}
diff --git a/src/markdown/support/troubleshooting-guide.md b/src/markdown/support/troubleshooting-guide.md
index 1cf6b2f2..685e8f7a 100644
--- a/src/markdown/support/troubleshooting-guide.md
+++ b/src/markdown/support/troubleshooting-guide.md
@@ -152,21 +152,6 @@ The online security resources may give a pointer how to fix this.
[Maintaining the Known Hosts File](../setup/maintain_known_hosts.md) to add one
or more entries to your known hosts file.
-### Unexplained timeout or disconnection from BI tool while using SSH tunnel properties
-
-#### What to look for:
-
-- Timeout connecting to the database server.
-- Connection refused on local machine.
-- Disconnection error.
-
-#### What to do:
-
-1. Close all BI tools that might be using the Amazon DocumentDB JDBC Driver.
-2. Delete all the SSH lock file folders found under `~/.documentdb/sshTunnelLocks`.
-3. If you are unable to delete some of these folders due to files being locked, restart
-your computer to force the files to become unlocked. Then repeat step 2.
-
## Schema Issues
### Schema Out of Date
diff --git a/src/test/java/software/amazon/documentdb/jdbc/DocumentDbConnectionTest.java b/src/test/java/software/amazon/documentdb/jdbc/DocumentDbConnectionTest.java
index fd5d36c8..a2e2d89b 100644
--- a/src/test/java/software/amazon/documentdb/jdbc/DocumentDbConnectionTest.java
+++ b/src/test/java/software/amazon/documentdb/jdbc/DocumentDbConnectionTest.java
@@ -16,7 +16,6 @@
package software.amazon.documentdb.jdbc;
-import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
@@ -25,8 +24,6 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
import software.amazon.documentdb.jdbc.common.test.DocumentDbFlapDoodleExtension;
import software.amazon.documentdb.jdbc.common.test.DocumentDbFlapDoodleTest;
import software.amazon.documentdb.jdbc.common.test.DocumentDbTestEnvironment;
@@ -35,34 +32,26 @@
import software.amazon.documentdb.jdbc.metadata.DocumentDbSchema;
import software.amazon.documentdb.jdbc.persist.DocumentDbSchemaReader;
import software.amazon.documentdb.jdbc.persist.DocumentDbSchemaWriter;
-import software.amazon.documentdb.jdbc.sshtunnel.DocumentDbSshTunnelServer;
-
-import java.io.BufferedReader;
-import java.io.IOException;
-import java.io.InputStreamReader;
-import java.io.UncheckedIOException;
-import java.net.URISyntaxException;
-import java.nio.charset.StandardCharsets;
+
+import java.security.SecureRandom;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
-import java.time.Duration;
import java.time.Instant;
+import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
-import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.isNullOrWhitespace;
-
@ExtendWith(DocumentDbFlapDoodleExtension.class)
public class DocumentDbConnectionTest extends DocumentDbFlapDoodleTest {
- private static final Logger LOGGER = LoggerFactory.getLogger(DocumentDbConnectionTest.class);
-
private static final String HOSTNAME = "localhost";
private static final String USERNAME = "user";
private static final String PASSWORD = "password";
@@ -490,89 +479,37 @@ void testMultiProcessConnections(final DocumentDbTestEnvironment environment) th
}
environment.start();
- final String connectionString = DocumentDbConnectionPropertiesTest
- .buildInternalSshTunnelConnectionString(environment);
- final int maxWaitTimePerClient = 5;
- final List commandLine = getCommandLine(connectionString, maxWaitTimePerClient);
- final List processes = startClientRunnerProcesses(commandLine);
- final Instant timeoutTime = Instant.now().plus(Duration.ofSeconds(maxWaitTimePerClient * 3));
- assertProcessesCompleteNormally(processes, timeoutTime);
- }
-
- private static List getCommandLine(final String connectionString, final int maxWaitTimePerClient)
- throws SQLException, URISyntaxException {
- final int numberOfClientsPerProcess = 5;
- // This class name is provided in text because I don't want to mark the class public.
- // For testing purposes, the class needs to be in the 'main' distribution - not the 'test' distribution.
- final String clientRunnerClassName = DocumentDbSshTunnelServer.class.getPackage().getName()
- + "." + "DocumentDbSshTunnelTestClientRunner";
- return DocumentDbSshTunnelServer.getJavaCommand(
- clientRunnerClassName,
- connectionString,
- String.valueOf(numberOfClientsPerProcess),
- String.valueOf(maxWaitTimePerClient));
- }
-
- @SuppressFBWarnings("COMMAND_INJECTION")
- private static List startClientRunnerProcesses(final List commandLine) throws IOException {
- final List processes = new ArrayList<>();
- final int processCount = 5;
- for (int i = 0; i < processCount; i++) {
- processes.add(new ProcessBuilder(commandLine).start());
+ final int numberOfConnections = 100;
+ final List runners = new ArrayList<>();
+ final List threads = new ArrayList<>();
+ final DocumentDbConnectionProperties internalSSHTunnelProperties = getInternalSSHTunnelProperties(environment);
+ for (int i = 0; i < numberOfConnections; i++) {
+ final Runner runner = new Runner(internalSSHTunnelProperties);
+ final Thread thread = new Thread(runner);
+ runners.add(runner);
+ threads.add(thread);
}
- return processes;
- }
-
- private static void assertProcessesCompleteNormally(final List processes, final Instant timeoutTime)
- throws InterruptedException, IOException {
- boolean timeoutReached = false;
- for (Process process : processes) {
- synchronized (process) {
- while (process.isAlive()) {
- final Instant now = Instant.now();
- if (now.isAfter(timeoutTime)) {
- LOGGER.debug("Timeout reached.");
- timeoutReached = true;
- break;
- }
- process.waitFor(500, TimeUnit.MILLISECONDS);
- }
- if (!timeoutReached) {
- LOGGER.debug("Closed before timeout.");
- }
- // Forceful shutdown, if necessary.
- if (process.isAlive()) {
- LOGGER.debug("Destroying process.");
- process.destroy();
- }
- if (process.isAlive()) {
- LOGGER.debug("Forcibly destroying process.");
- process.destroyForcibly();
- }
- try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
- final String stdOut = bufferedReader.lines().collect(Collectors.joining(System.lineSeparator()));
- if (!isNullOrWhitespace(stdOut)) {
- LOGGER.debug("Process output: '" + stdOut + "'");
- }
- } catch (IOException | UncheckedIOException e) {
- // Ignore exceptions - might already be closed.
- }
- try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(process.getErrorStream(), StandardCharsets.UTF_8))) {
- final String stdErr = bufferedReader.lines().collect(Collectors.joining(System.lineSeparator()));
- if (!isNullOrWhitespace(stdErr)) {
- LOGGER.debug("Process error: '" + stdErr + "'");
- }
- } catch (IOException | UncheckedIOException e) {
- // Ignore exceptions - might already be closed.
+ for (final Thread thread : threads) {
+ thread.start();
+ }
+ while (threads.size() > 0) {
+ TimeUnit.MILLISECONDS.sleep(100);
+ for (int i = threads.size() - 1; i >= 0; i--) {
+ final Thread thread = threads.get(i);
+ if (!thread.isAlive()) {
+ thread.join();
+ threads.remove(i);
}
}
}
- // If client runner ran successfully, it returns zero (0) exit value.
- for (Process process : processes) {
- Assertions.assertEquals(0, process.exitValue());
+ for (final Runner runner : runners) {
+ final Queue exceptions = runner.exceptions;
+ Assertions.assertEquals(0, exceptions.size(),
+ () -> exceptions.stream()
+ .map(e -> e.getMessage())
+ .collect(Collectors.joining("; ")));
+
}
- // Double check we haven't timed-out.
- Assertions.assertFalse(timeoutReached);
}
private Stream getDocumentDb40SshTunnelEnvironmentSourceOrNull() {
@@ -612,4 +549,43 @@ public static DocumentDbConnectionProperties getInternalSSHTunnelProperties(
properties.setSshStrictHostKeyChecking("false");
return properties;
}
+
+ private static class Runner implements Runnable {
+ private static final SecureRandom RANDOM = new SecureRandom();
+ private final DocumentDbConnectionProperties properties;
+ private final Queue exceptions = new ConcurrentLinkedQueue<>();
+
+ Runner(final DocumentDbConnectionProperties properties) {
+ this.properties = properties;
+ }
+
+ @Override
+ public void run() {
+ final int timeToWaitSECS = RANDOM.nextInt(5) + 1;
+ final Instant timeoutTime = Instant.now().plus(timeToWaitSECS, ChronoUnit.SECONDS);
+ DocumentDbConnection connection = null;
+ try {
+ connection = new DocumentDbConnection(properties);
+ while (timeoutTime.isAfter(Instant.now())) {
+ connection.isValid(1);
+ TimeUnit.MILLISECONDS.sleep(100);
+ }
+ } catch (Exception e) {
+ exceptions.add(e);
+ } finally {
+ if (connection != null) {
+ try {
+ connection.close();
+ Assertions.assertFalse(connection.isValid(1));
+ } catch (Exception e) {
+ exceptions.add(e);
+ }
+ }
+ }
+ }
+
+ Queue getExceptions() {
+ return exceptions;
+ }
+ }
}
diff --git a/src/test/java/software/amazon/documentdb/jdbc/DocumentDbMainTest.java b/src/test/java/software/amazon/documentdb/jdbc/DocumentDbMainTest.java
index 073aaba2..4cf7bedf 100644
--- a/src/test/java/software/amazon/documentdb/jdbc/DocumentDbMainTest.java
+++ b/src/test/java/software/amazon/documentdb/jdbc/DocumentDbMainTest.java
@@ -72,26 +72,11 @@ class DocumentDbMainTest {
.compile(NEW_DEFAULT_SCHEMA_ANY_VERSION_REGEX);
private DocumentDbConnectionProperties properties;
public static final Path USER_HOME_PATH = Paths.get(System.getProperty(USER_HOME_PROPERTY));
- private static final String DOC_DB_PRIV_KEY_FILE_PROPERTY = "DOC_DB_PRIV_KEY_FILE";
- private static final String DOC_DB_USER_PROPERTY = "DOC_DB_USER";
- private static final String DOC_DB_HOST_PROPERTY = "DOC_DB_HOST";
private static Stream getTestEnvironments() {
return DocumentDbTestEnvironmentFactory.getConfiguredEnvironments().stream();
}
- private static Stream getDocumentDb40SshTunnelEnvironmentSourceOrNull() {
- if (DocumentDbTestEnvironmentFactory.getConfiguredEnvironments().stream()
- .anyMatch(e -> e == DocumentDbTestEnvironmentFactory
- .getDocumentDb40SshTunnelEnvironment())) {
- return DocumentDbTestEnvironmentFactory.getConfiguredEnvironments().stream()
- .filter(e -> e == DocumentDbTestEnvironmentFactory
- .getDocumentDb40SshTunnelEnvironment());
- } else {
- return Stream.of((DocumentDbTestEnvironment) null);
- }
- }
-
@BeforeAll
static void beforeAll() throws Exception {
for (DocumentDbTestEnvironment environment : getTestEnvironments()
@@ -841,22 +826,6 @@ void testExportFileToDirectoryError(final DocumentDbTestEnvironment testEnvironm
}
}
- @ParameterizedTest(name = "testSshTunnelCommand - [{index}] - {arguments}")
- @MethodSource("getDocumentDb40SshTunnelEnvironmentSourceOrNull")
- void testSshTunnelCommand(final DocumentDbTestEnvironment testEnvironment) throws SQLException {
- // NOTE: a "null" environment means it isn't configured to run. So bypass.
- if (testEnvironment == null) {
- return;
- }
- setConnectionProperties(testEnvironment);
-
- final String connectionString = getSshConnectionString();
- final StringBuilder output = new StringBuilder();
- final String[] args = {"--ssh-tunnel", connectionString};
- DocumentDbMain.handleCommandLine(args, output);
- Assertions.assertEquals("", output.toString());
- }
-
private String createSimpleCollection(final DocumentDbTestEnvironment testEnvironment)
throws SQLException {
final String collectionName;
@@ -1047,20 +1016,4 @@ private static String getExpectedExportContent(
builder.append(" ]");
return builder.toString();
}
-
- static String getSshConnectionString() {
- final String docDbRemoteHost = System.getenv(DOC_DB_HOST_PROPERTY);
- final String docDbSshUserAndHost = System.getenv(DOC_DB_USER_PROPERTY);
- final int userSeparatorIndex = docDbSshUserAndHost.indexOf('@');
- final String sshUser = docDbSshUserAndHost.substring(0, userSeparatorIndex);
- final String sshHostname = docDbSshUserAndHost.substring(userSeparatorIndex + 1);
- final String docDbSshPrivKeyFile = System.getenv(DOC_DB_PRIV_KEY_FILE_PROPERTY);
- final DocumentDbConnectionProperties properties = new DocumentDbConnectionProperties();
- properties.setHostname(docDbRemoteHost);
- properties.setSshUser(sshUser);
- properties.setSshHostname(sshHostname);
- properties.setSshPrivateKeyFile(docDbSshPrivKeyFile);
- properties.setSshStrictHostKeyChecking(String.valueOf(false));
- return DocumentDbConnectionProperties.DOCUMENT_DB_SCHEME + properties.buildSshConnectionString();
- }
}
diff --git a/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelClientTest.java b/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelClientTest.java
index 9a733a54..282f0601 100644
--- a/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelClientTest.java
+++ b/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelClientTest.java
@@ -29,6 +29,9 @@
import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.ValidationType.SSH_TUNNEL;
class DocumentDbSshTunnelClientTest {
+ private static final String DOC_DB_PRIV_KEY_FILE_PROPERTY = "DOC_DB_PRIV_KEY_FILE";
+ private static final String DOC_DB_USER_PROPERTY = "DOC_DB_USER";
+ private static final String DOC_DB_HOST_PROPERTY = "DOC_DB_HOST";
@Test
@Tag("remote-integration")
@@ -40,6 +43,7 @@ void testConstructorDestructor() throws Exception {
try {
client = new DocumentDbSshTunnelClient(properties);
server = client.getSshTunnelServer();
+ server.setCloseDelayMS(1000);
Assertions.assertTrue(client.getServiceListeningPort() > 0);
TimeUnit.SECONDS.sleep(1);
Assertions.assertTrue(client.isServerAlive());
@@ -49,7 +53,7 @@ void testConstructorDestructor() throws Exception {
if (client != null) {
client.close();
// This is the only client, so server will shut down.
- TimeUnit.SECONDS.sleep(1);
+ TimeUnit.MILLISECONDS.sleep(server.getCloseDelayMS() + 500);
Assertions.assertNotNull(server);
Assertions.assertFalse(client.isServerAlive());
}
@@ -83,8 +87,17 @@ void testMultipleClientsSameServer() throws Exception {
clients.add(client);
}
} finally {
+ int clientCount = clients.size();
+ final DocumentDbSshTunnelServer server = clients.get(0).getSshTunnelServer();
+ server.setCloseDelayMS(0);
for (DocumentDbSshTunnelClient client : clients) {
client.close();
+ clientCount--;
+ if (clientCount > 0) {
+ Assertions.assertTrue(client.getSshTunnelServer().isAlive());
+ } else {
+ Assertions.assertFalse(client.getSshTunnelServer().isAlive());
+ }
}
}
}
@@ -98,8 +111,8 @@ void testInvalidSshHostnameConnectionTimeout() throws Exception {
final Exception e = Assertions.assertThrows(
SQLException.class,
() -> new DocumentDbSshTunnelClient(properties));
- Assertions.assertTrue(
- e.toString().startsWith("java.sql.SQLException: Error reported from SSH Tunnel service."));
+ Assertions.assertTrue(e.toString().startsWith(
+ "java.sql.SQLException: java.net.ConnectException: Connection timed out"));
}
@Test
@@ -111,8 +124,7 @@ void testInvalidSshUserAuthFail() throws Exception {
final Exception e = Assertions.assertThrows(
SQLException.class,
() -> new DocumentDbSshTunnelClient(properties));
- Assertions.assertEquals("java.sql.SQLException: Error reported from SSH Tunnel service."
- + " (Server exception detected: 'java.sql.SQLException: Auth fail for methods 'publickey,gssapi-keyex,gssapi-with-mic'')",
+ Assertions.assertEquals("java.sql.SQLException: Auth fail for methods 'publickey,gssapi-keyex,gssapi-with-mic'",
e.toString());
}
@@ -144,7 +156,23 @@ void testInvalidSshKnownHostsFileNotFound() throws Exception {
}
private static DocumentDbConnectionProperties getConnectionProperties() throws SQLException {
- final String connectionString = DocumentDbSshTunnelServiceTest.getConnectionString();
+ final String connectionString = getConnectionString();
return DocumentDbConnectionProperties.getPropertiesFromConnectionString(connectionString, SSH_TUNNEL);
}
+
+ static String getConnectionString() {
+ final String docDbRemoteHost = System.getenv(DOC_DB_HOST_PROPERTY);
+ final String docDbSshUserAndHost = System.getenv(DOC_DB_USER_PROPERTY);
+ final int userSeparatorIndex = docDbSshUserAndHost.indexOf('@');
+ final String sshUser = docDbSshUserAndHost.substring(0, userSeparatorIndex);
+ final String sshHostname = docDbSshUserAndHost.substring(userSeparatorIndex + 1);
+ final String docDbSshPrivKeyFile = System.getenv(DOC_DB_PRIV_KEY_FILE_PROPERTY);
+ final DocumentDbConnectionProperties properties = new DocumentDbConnectionProperties();
+ properties.setHostname(docDbRemoteHost);
+ properties.setSshUser(sshUser);
+ properties.setSshHostname(sshHostname);
+ properties.setSshPrivateKeyFile(docDbSshPrivKeyFile);
+ properties.setSshStrictHostKeyChecking(String.valueOf(false));
+ return DocumentDbConnectionProperties.DOCUMENT_DB_SCHEME + properties.buildSshConnectionString();
+ }
}
diff --git a/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServerTest.java b/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServerTest.java
index b9b63705..07034c25 100644
--- a/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServerTest.java
+++ b/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServerTest.java
@@ -21,39 +21,315 @@
import org.junit.jupiter.api.Test;
import software.amazon.documentdb.jdbc.DocumentDbConnectionProperties;
+import java.sql.SQLException;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.ValidationType.SSH_TUNNEL;
class DocumentDbSshTunnelServerTest {
+ private final Object mutex = new Object();
+
@Test
@Tag("remote-integration")
- void testEnsureStarted() throws Exception {
- final String connectionString = DocumentDbSshTunnelServiceTest.getConnectionString();
+ void testAddRemoveClient() throws Exception {
+ final String connectionString = DocumentDbSshTunnelClientTest.getConnectionString();
final DocumentDbConnectionProperties properties =
DocumentDbConnectionProperties.getPropertiesFromConnectionString(connectionString, SSH_TUNNEL);
- DocumentDbSshTunnelServer server = null;
- final int timeoutSECS = 3;
+ final DocumentDbSshTunnelServer server = DocumentDbSshTunnelServer.builder(
+ properties.getSshUser(),
+ properties.getSshHostname(),
+ properties.getSshPrivateKeyFile(),
+ properties.getHostname())
+ .sshStrictHostKeyChecking(properties.getSshStrictHostKeyChecking())
+ .build();
+ final int timeoutSECS = 1;
try {
- server = DocumentDbSshTunnelServer.builder(
- properties.getSshUser(),
- properties.getSshHostname(),
- properties.getSshPrivateKeyFile(),
- properties.getHostname())
- .sshStrictHostKeyChecking(properties.getSshStrictHostKeyChecking())
- .build();
- server.ensureStarted();
+ server.addClient();
Assertions.assertTrue(server.getServiceListeningPort() > 0);
TimeUnit.SECONDS.sleep(timeoutSECS);
- // No clients registered - so should exit
+ Assertions.assertTrue(server.isAlive());
+ } finally {
+ server.setCloseDelayMS(0);
+ Assertions.assertNotNull(server);
+ server.removeClient();
+ Assertions.assertEquals(0, server.getServiceListeningPort());
+ TimeUnit.SECONDS.sleep(timeoutSECS);
+ Assertions.assertFalse(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ // Extra remove is ignored.
+ Assertions.assertDoesNotThrow(server::removeClient);
+ Assertions.assertFalse(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ }
+ }
+
+ @Test
+ @Tag("remote-integration")
+ void testAddRemoveClientDelayedClose() throws Exception {
+ final String connectionString = DocumentDbSshTunnelClientTest.getConnectionString();
+ final DocumentDbConnectionProperties properties =
+ DocumentDbConnectionProperties.getPropertiesFromConnectionString(connectionString, SSH_TUNNEL);
+ final DocumentDbSshTunnelServer server = DocumentDbSshTunnelServer.builder(
+ properties.getSshUser(),
+ properties.getSshHostname(),
+ properties.getSshPrivateKeyFile(),
+ properties.getHostname())
+ .sshStrictHostKeyChecking(properties.getSshStrictHostKeyChecking())
+ .build();
+ final int timeoutSECS = 1;
+ try {
+ server.addClient();
+ Assertions.assertTrue(server.getServiceListeningPort() > 0);
+ TimeUnit.SECONDS.sleep(timeoutSECS);
+ Assertions.assertTrue(server.isAlive());
+ } finally {
+ Assertions.assertNotNull(server);
+ final int closeDelayMS = 5000;
+ final int closeDelayTimeWithBuffer = closeDelayMS;
+ server.setCloseDelayMS(closeDelayMS);
+ server.removeClient();
+ final Instant expectedCloseTime = Instant.now().plus(closeDelayTimeWithBuffer, ChronoUnit.MILLIS);
+ while (Instant.now().isBefore(expectedCloseTime)) {
+ Assertions.assertTrue(server.getServiceListeningPort() != 0);
+ Assertions.assertTrue(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ TimeUnit.MILLISECONDS.sleep(100);
+ }
+ TimeUnit.MILLISECONDS.sleep(100);
+ Assertions.assertEquals(0, server.getServiceListeningPort());
+ TimeUnit.SECONDS.sleep(timeoutSECS);
+ Assertions.assertFalse(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ // Extra remove is ignored.
+ Assertions.assertDoesNotThrow(server::removeClient);
+ Assertions.assertFalse(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ }
+ }
+
+ @Test
+ @Tag("remote-integration")
+ void testAddRemoveClientBeforeDelayedClose() throws Exception {
+ final String connectionString = DocumentDbSshTunnelClientTest.getConnectionString();
+ final DocumentDbConnectionProperties properties =
+ DocumentDbConnectionProperties.getPropertiesFromConnectionString(connectionString, SSH_TUNNEL);
+ final DocumentDbSshTunnelServer server = DocumentDbSshTunnelServer.builder(
+ properties.getSshUser(),
+ properties.getSshHostname(),
+ properties.getSshPrivateKeyFile(),
+ properties.getHostname())
+ .sshStrictHostKeyChecking(properties.getSshStrictHostKeyChecking())
+ .build();
+ final int closeDelayMS = 2000;
+ final int closeDelayTimeWithBuffer = closeDelayMS;
+ final int timeoutSECS = 1;
+ try {
+ Assertions.assertNotNull(server);
+ server.setCloseDelayMS(closeDelayMS);
+ server.addClient();
+ Assertions.assertEquals(1, server.getClientCount());
+ Assertions.assertTrue(server.getServiceListeningPort() > 0);
+ TimeUnit.MILLISECONDS.sleep(closeDelayMS * 2);
+ Assertions.assertTrue(server.isAlive());
+ server.removeClient();
+ Assertions.assertEquals(0, server.getClientCount());
+ Assertions.assertTrue(server.isAlive());
+ TimeUnit.MILLISECONDS.sleep(closeDelayMS / 2);
+ server.addClient();
+ Assertions.assertEquals(1, server.getClientCount());
+ Assertions.assertTrue(server.isAlive());
+ TimeUnit.MILLISECONDS.sleep(closeDelayMS * 2);
+ Assertions.assertTrue(server.isAlive());
+ } finally {
+ Assertions.assertNotNull(server);
+ server.setCloseDelayMS(closeDelayMS);
+ server.removeClient();
+ final Instant expectedCloseTime = Instant.now().plus(closeDelayTimeWithBuffer, ChronoUnit.MILLIS);
+ while (Instant.now().isBefore(expectedCloseTime)) {
+ Assertions.assertTrue(server.getServiceListeningPort() != 0);
+ Assertions.assertTrue(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ TimeUnit.MILLISECONDS.sleep(100);
+ }
+ TimeUnit.MILLISECONDS.sleep(100);
+ Assertions.assertEquals(0, server.getServiceListeningPort());
+ TimeUnit.SECONDS.sleep(timeoutSECS);
Assertions.assertFalse(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ // Extra remove is ignored.
+ Assertions.assertDoesNotThrow(server::removeClient);
+ Assertions.assertFalse(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ }
+ }
+
+ @Test
+ @Tag("remote-integration")
+ void testAddRemoveClientAfterDelayedClose() throws Exception {
+ final String connectionString = DocumentDbSshTunnelClientTest.getConnectionString();
+ final DocumentDbConnectionProperties properties =
+ DocumentDbConnectionProperties.getPropertiesFromConnectionString(connectionString, SSH_TUNNEL);
+ final DocumentDbSshTunnelServer server = DocumentDbSshTunnelServer.builder(
+ properties.getSshUser(),
+ properties.getSshHostname(),
+ properties.getSshPrivateKeyFile(),
+ properties.getHostname())
+ .sshStrictHostKeyChecking(properties.getSshStrictHostKeyChecking())
+ .build();
+ final int closeDelayMS = 2000;
+ final int closeDelayTimeWithBuffer = closeDelayMS;
+ final int timeoutSECS = 1;
+ try {
+ Assertions.assertNotNull(server);
+ server.setCloseDelayMS(closeDelayMS);
+ server.addClient();
+ Assertions.assertEquals(1, server.getClientCount());
+ Assertions.assertTrue(server.getServiceListeningPort() > 0);
+ TimeUnit.MILLISECONDS.sleep(closeDelayMS * 2);
+ Assertions.assertTrue(server.isAlive());
+ server.removeClient();
+ Assertions.assertEquals(0, server.getClientCount());
+ Assertions.assertTrue(server.isAlive());
+ TimeUnit.MILLISECONDS.sleep(closeDelayMS * 2);
+ Assertions.assertFalse(server.isAlive());
+ server.addClient();
+ Assertions.assertEquals(1, server.getClientCount());
+ Assertions.assertTrue(server.isAlive());
+ TimeUnit.MILLISECONDS.sleep(closeDelayMS * 2);
+ Assertions.assertTrue(server.isAlive());
} finally {
Assertions.assertNotNull(server);
- server.close();
+ server.setCloseDelayMS(closeDelayMS);
+ server.removeClient();
+ final Instant expectedCloseTime = Instant.now().plus(closeDelayTimeWithBuffer, ChronoUnit.MILLIS);
+ while (Instant.now().isBefore(expectedCloseTime)) {
+ Assertions.assertTrue(server.getServiceListeningPort() != 0);
+ Assertions.assertTrue(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ TimeUnit.MILLISECONDS.sleep(100);
+ }
+ TimeUnit.MILLISECONDS.sleep(100);
Assertions.assertEquals(0, server.getServiceListeningPort());
TimeUnit.SECONDS.sleep(timeoutSECS);
Assertions.assertFalse(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ // Extra remove is ignored.
+ Assertions.assertDoesNotThrow(server::removeClient);
+ Assertions.assertFalse(server.isAlive());
+ Assertions.assertEquals(0, server.getClientCount());
+ }
+ }
+
+ @Test
+ @Tag("remote-integration")
+ void testAddRemoveClientMultiThreaded() throws SQLException, InterruptedException {
+ final int numOfThreads = 10;
+ final List threads = new ArrayList<>();
+ final List runners = new ArrayList<>();
+ final String connectionString = DocumentDbSshTunnelClientTest.getConnectionString();
+ final DocumentDbConnectionProperties properties =
+ DocumentDbConnectionProperties.getPropertiesFromConnectionString(connectionString, SSH_TUNNEL);
+
+ final DocumentDbSshTunnelServer server = DocumentDbSshTunnelServer.builder(
+ properties.getSshUser(),
+ properties.getSshHostname(),
+ properties.getSshPrivateKeyFile(),
+ properties.getHostname())
+ .sshStrictHostKeyChecking(properties.getSshStrictHostKeyChecking())
+ .build();
+ Assertions.assertNotNull(server);
+ server.setCloseDelayMS(0);
+
+ // Create all the runners and assign them to a thread.
+ for (int i = 0; i < numOfThreads; i++) {
+ final int runtimeSecs = numOfThreads - i;
+ final Runner runner = new Runner(runtimeSecs, server);
+ final Thread threadRunner = new Thread(runner);
+ runners.add(runner);
+ threads.add(threadRunner);
+ }
+
+ // Start all the threads.
+ for (int i = 0; i < numOfThreads; i++) {
+ threads.get(i).start();
+ }
+
+ // Wait for the threads to complete.
+ TimeUnit.SECONDS.sleep(1);
+ while (threads.size() > 0) {
+ TimeUnit.MILLISECONDS.sleep(100);
+ synchronized (mutex) {
+ // Allow thread to exit after releasing the MUTEX.
+ TimeUnit.MILLISECONDS.sleep(10);
+ final long clientCount = server.getClientCount();
+ int threadCount = 0;
+ for (int i = threads.size() - 1; i >= 0; i--) {
+ if (threads.get(i).isAlive()) {
+ threadCount++;
+ Assertions.assertTrue(server.isAlive());
+ } else {
+ threads.get(i).join();
+ threads.remove(i);
+ }
+ }
+ Assertions.assertEquals(clientCount, threadCount);
+ Assertions.assertTrue((clientCount > 0 && server.isAlive()) || !server.isAlive());
+ }
+ }
+
+ // Ensure no more clients and no longer alive.
+ Assertions.assertEquals(0, server.getClientCount());
+ Assertions.assertFalse(server.isAlive());
+
+ // Ensure clients didn't throw any exceptions.
+ for (final Runner runner : runners) {
+ Assertions.assertEquals(0, runner.getExceptions().size(),
+ () -> runner.getExceptions().stream()
+ .map(Throwable::getMessage)
+ .collect(Collectors.joining("; ")));
+ }
+ }
+
+ private class Runner implements Runnable {
+ private final int runtimeSecs;
+ private final DocumentDbSshTunnelServer server;
+ private final Queue exceptions = new ConcurrentLinkedDeque<>();
+
+ public Runner(final int runtimeSecs, final DocumentDbSshTunnelServer server) {
+ this.runtimeSecs = runtimeSecs;
+ this.server = server;
+ }
+
+ public Queue getExceptions() {
+ return exceptions;
+ }
+
+ @Override
+ public void run() {
+ try {
+ synchronized (mutex) {
+ server.addClient();
+ }
+ TimeUnit.SECONDS.sleep(runtimeSecs);
+ } catch (Exception e) {
+ exceptions.add(e);
+ } finally {
+ try {
+ synchronized (mutex) {
+ server.removeClient();
+ }
+ } catch (Exception e) {
+ exceptions.add(e);
+ }
+ }
}
}
}
diff --git a/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServiceTest.java b/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServiceTest.java
deleted file mode 100644
index 6c361949..00000000
--- a/src/test/java/software/amazon/documentdb/jdbc/sshtunnel/DocumentDbSshTunnelServiceTest.java
+++ /dev/null
@@ -1,246 +0,0 @@
-/*
- * Copyright <2022> Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License").
- * You may not use this file except in compliance with the License.
- * A copy of the License is located at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * or in the "license" file accompanying this file. This file is distributed
- * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied. See the License for the specific language governing
- * permissions and limitations under the License.
- *
- */
-
-package software.amazon.documentdb.jdbc.sshtunnel;
-
-import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.DisplayName;
-import org.junit.jupiter.api.Tag;
-import org.junit.jupiter.api.Test;
-import software.amazon.documentdb.jdbc.DocumentDbConnectionProperties;
-
-import java.io.File;
-import java.io.IOException;
-import java.nio.channels.FileLock;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.StandardOpenOption;
-import java.util.UUID;
-
-import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.getDocumentDbSearchPaths;
-import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.getPath;
-import static software.amazon.documentdb.jdbc.sshtunnel.DocumentDbSshTunnelLock.getClassPathLocationName;
-import static software.amazon.documentdb.jdbc.sshtunnel.DocumentDbSshTunnelLock.getDocumentdbHomePathName;
-import static software.amazon.documentdb.jdbc.sshtunnel.DocumentDbSshTunnelLock.getUserHomePathName;
-
-class DocumentDbSshTunnelServiceTest {
- private static final String DOC_DB_PRIV_KEY_FILE_PROPERTY = "DOC_DB_PRIV_KEY_FILE";
- private static final String DOC_DB_USER_PROPERTY = "DOC_DB_USER";
- private static final String DOC_DB_HOST_PROPERTY = "DOC_DB_HOST";
-
- @Test
- @Tag("remote-integration")
- @DisplayName("Tests that SSH tunnel service can be started and stays alive while client lock exists.")
- void testRun() throws Exception {
- final String connectionString = getConnectionString();
- String sshPropertiesHashString = null;
- final int waitTimeMS = 2000;
-
- try (DocumentDbSshTunnelService service = new DocumentDbSshTunnelService(connectionString)) {
- // Ensure the lock directory is empty.
- sshPropertiesHashString = service.getSshPropertiesHashString();
- DocumentDbSshTunnelLock.deleteLockDirectory(sshPropertiesHashString);
-
- // Prepare to create a client lock.
- final UUID unique = UUID.randomUUID();
- final Path clientLockPath = DocumentDbSshTunnelLock.getClientLockPath(
- unique, service.getSshPropertiesHashString());
- DocumentDbMultiThreadFileChannel clientChannelCopy = null;
- FileLock clientLockCopy = null;
-
- try {
- final Path parentPath = clientLockPath.getParent();
- assert parentPath != null;
- Files.createDirectories(parentPath);
- // Create and lock the client lock file.
- final DocumentDbMultiThreadFileChannel clientChannel = DocumentDbMultiThreadFileChannel.open(
- clientLockPath, StandardOpenOption.CREATE_NEW, StandardOpenOption.WRITE);
- final FileLock clientLock = clientChannel.lock();
- clientChannelCopy = clientChannel;
- clientLockCopy = clientLock;
-
- // Start the service thread and confirm it is alive.
- final Thread serviceThread = startServiceAndValidateIsAlive(service);
-
- // Close the client lock to signal the service
- final Exception closeException = closeClientLockInGlobalLock(
- service, clientLockPath, clientChannel, clientLock);
- Assertions.assertNull(closeException);
-
- // Wait and confirm service has stopped.
- serviceThread.join(waitTimeMS);
- Assertions.assertFalse(serviceThread.isAlive());
-
- // Validate there are no returned exceptions.
- validateExceptions(service);
- } finally {
- releaseResources(service, clientChannelCopy, clientLockCopy);
- }
- } finally {
- // Clean-up the lock file directory
- assert sshPropertiesHashString != null;
- DocumentDbSshTunnelLock.deleteLockDirectory(sshPropertiesHashString);
- }
- }
-
- @Test()
- @DisplayName("Tests the getPath method.")
- @SuppressFBWarnings("PATH_TRAVERSAL_IN")
- void testGetPath() throws IOException {
- final String tempFilename1 = UUID.randomUUID().toString();
-
- // Test that it will return using the "current directory"
- final Path path1 = getPath(tempFilename1);
- Assertions.assertEquals(Paths.get(tempFilename1).toAbsolutePath(), path1);
-
- // Test that it will use the user's home path
- final Path path2 = getPath("~/" + tempFilename1);
- Assertions.assertEquals(Paths.get(getUserHomePathName(), tempFilename1), path2);
-
- // Test that it will use the user's home path
- Path homeTempFilePath = null;
- try {
- homeTempFilePath = Paths.get(getUserHomePathName(), tempFilename1);
- Assertions.assertTrue(homeTempFilePath.toFile().createNewFile());
- final Path path3 = getPath(tempFilename1, getDocumentDbSearchPaths());
- Assertions.assertEquals(Paths.get(getUserHomePathName(), tempFilename1), path3);
- } finally {
- Assertions.assertTrue(homeTempFilePath != null && homeTempFilePath.toFile().delete());
- }
-
- // Test that it will use the .documentdb folder under the user's home path
- Path documentDbTempFilePath = null;
- try {
- documentDbTempFilePath = Paths.get(getDocumentdbHomePathName(), tempFilename1);
- final File documentDbDirectory = Paths.get(getDocumentdbHomePathName()).toFile();
- if (!documentDbDirectory.exists()) {
- Assertions.assertTrue(documentDbDirectory.mkdir());
- }
- Assertions.assertTrue(documentDbTempFilePath.toFile().createNewFile());
- final Path path4 = getPath(tempFilename1, getDocumentDbSearchPaths());
- Assertions.assertEquals(Paths.get(getDocumentdbHomePathName(), tempFilename1), path4);
- } finally {
- Assertions.assertTrue(documentDbTempFilePath != null && documentDbTempFilePath.toFile().delete());
- }
-
- // Test that it will use the .documentdb folder under the user's home path
- Path classPathParentTempFilePath = null;
- try {
- classPathParentTempFilePath = Paths.get(getClassPathLocationName(), tempFilename1);
- Assertions.assertTrue(classPathParentTempFilePath.toFile().createNewFile());
- final Path path5 = getPath(tempFilename1, getDocumentDbSearchPaths());
- Assertions.assertEquals(Paths.get(getClassPathLocationName(), tempFilename1), path5);
- } finally {
- Assertions.assertTrue(classPathParentTempFilePath != null && classPathParentTempFilePath.toFile().delete());
- }
-
- // Test that will recognize and use an absolute path
- File tempFile = null;
- try {
- tempFile = File.createTempFile("documentdb", ".tmp");
- final Path path5 = getPath(tempFile.getAbsolutePath());
- Assertions.assertEquals(Paths.get(tempFile.getAbsolutePath()), path5);
- } finally {
- Assertions.assertTrue(tempFile != null && tempFile.delete());
- }
- }
-
- private static void releaseResources(
- final DocumentDbSshTunnelService service,
- final DocumentDbMultiThreadFileChannel clientChannelCopy,
- final FileLock clientLockCopy) throws Exception {
- if (clientLockCopy != null && clientLockCopy.isValid()) {
- clientLockCopy.close();
- }
- if (clientChannelCopy != null && clientChannelCopy.isOpen()) {
- clientChannelCopy.close();
- }
- if (service != null) {
- service.close();
- }
- }
-
- private static Thread startServiceAndValidateIsAlive(final DocumentDbSshTunnelService service)
- throws InterruptedException {
- final int waitTimeMS = 2000;
- final Thread serviceThread = startThread(service);
- serviceThread.join(waitTimeMS);
- validateExceptions(service);
- Assertions.assertTrue(serviceThread.isAlive());
- validateExceptions(service);
- serviceThread.join(waitTimeMS);
- Assertions.assertTrue(serviceThread.isAlive());
- return serviceThread;
- }
-
- private static Exception closeClientLockInGlobalLock(
- final DocumentDbSshTunnelService service,
- final Path clientLockPath,
- final DocumentDbMultiThreadFileChannel clientChannel,
- final FileLock clientLock) throws Exception {
- return DocumentDbSshTunnelLock.runInGlobalLock(
- service.getSshPropertiesHashString(),
- () -> {
- Exception exception = null;
- try {
- clientLock.close();
- clientChannel.close();
- Files.deleteIfExists(clientLockPath);
- } catch (Exception e) {
- exception = e;
- }
- return exception;
- });
- }
-
- static String getConnectionString() {
- final String docDbRemoteHost = System.getenv(DOC_DB_HOST_PROPERTY);
- final String docDbSshUserAndHost = System.getenv(DOC_DB_USER_PROPERTY);
- final int userSeparatorIndex = docDbSshUserAndHost.indexOf('@');
- final String sshUser = docDbSshUserAndHost.substring(0, userSeparatorIndex);
- final String sshHostname = docDbSshUserAndHost.substring(userSeparatorIndex + 1);
- final String docDbSshPrivKeyFile = System.getenv(DOC_DB_PRIV_KEY_FILE_PROPERTY);
- final DocumentDbConnectionProperties properties = new DocumentDbConnectionProperties();
- properties.setHostname(docDbRemoteHost);
- properties.setSshUser(sshUser);
- properties.setSshHostname(sshHostname);
- properties.setSshPrivateKeyFile(docDbSshPrivKeyFile);
- properties.setSshStrictHostKeyChecking(String.valueOf(false));
- return DocumentDbConnectionProperties.DOCUMENT_DB_SCHEME + properties.buildSshConnectionString();
- }
-
- private static Thread startThread(final DocumentDbSshTunnelService service) {
- final Thread serviceThread = new Thread(service);
- serviceThread.setDaemon(true);
- serviceThread.start();
- return serviceThread;
- }
-
- private static void validateExceptions(final DocumentDbSshTunnelService service) {
- if (service.getExceptions().size() != 0) {
- for (Exception e : service.getExceptions()) {
- Assertions.assertInstanceOf(Exception.class, e);
- System.out.println(e.toString());
- for (StackTraceElement stackLine : e.getStackTrace()) {
- System.out.println(stackLine.toString());
- }
- }
- Assertions.fail();
- }
- }
-}