diff --git a/DESCRIPTION b/DESCRIPTION
index 3ddc1556..09cd9069 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -48,7 +48,8 @@ Suggests:
   odbc,
   duckdb,
   pool,
-  ParallelLogger
+  ParallelLogger,
+  AzureStor
 License: Apache License
 VignetteBuilder: knitr
 URL: https://ohdsi.github.io/DatabaseConnector/, https://github.com/OHDSI/DatabaseConnector
diff --git a/DatabaseConnector.Rproj b/DatabaseConnector.Rproj
index 4084b1e0..c5ea6c45 100644
--- a/DatabaseConnector.Rproj
+++ b/DatabaseConnector.Rproj
@@ -1,4 +1,5 @@
 Version: 1.0
+ProjectId: 9d51e576-41a3-432f-b696-8bfdc3eed676
 
 RestoreWorkspace: No
 SaveWorkspace: No
diff --git a/R/BulkLoad.R b/R/BulkLoad.R
index 31a62d6e..5fa09837 100644
--- a/R/BulkLoad.R
+++ b/R/BulkLoad.R
@@ -62,6 +62,25 @@ checkBulkLoadCredentials <- function(connection) {
       return(FALSE)
     }
     return(TRUE)
+  } else if (dbms(connection) == "spark") {
+    envSet <- FALSE
+    container <- FALSE
+
+    if (Sys.getenv("AZR_STORAGE_ACCOUNT") != "" && Sys.getenv("AZR_ACCOUNT_KEY") != "" && Sys.getenv("AZR_CONTAINER_NAME") != "") {
+      envSet <- TRUE
+    }
+
+    # List storage containers to confirm the container
+    # specified in the configuration exists
+    ensure_installed("AzureStor")
+    azureEndpoint <- getAzureEndpoint()
+    containerList <- getAzureContainerNames(azureEndpoint)
+
+    if (Sys.getenv("AZR_CONTAINER_NAME") %in% containerList) {
+      container <- TRUE
+    }
+
+    return(envSet && container)
   } else {
     return(FALSE)
   }
@@ -72,6 +91,18 @@ getHiveSshUser <- function() {
   return(if (sshUser == "") "root" else sshUser)
 }
 
+getAzureEndpoint <- function() {
+  azureEndpoint <- AzureStor::storage_endpoint(
+    paste0("https://", Sys.getenv("AZR_STORAGE_ACCOUNT"), ".dfs.core.windows.net"),
+    key = Sys.getenv("AZR_ACCOUNT_KEY")
+  )
+  return(azureEndpoint)
+}
+
+getAzureContainerNames <- function(azureEndpoint) {
+  return(names(AzureStor::list_storage_containers(azureEndpoint)))
+}
+
 countRows <- function(connection, sqlTableName) {
   sql <- "SELECT COUNT(*) FROM @table"
   count <- renderTranslateQuerySql(
@@ -354,3 +385,54 @@ bulkLoadPostgres <- function(connection, sqlTableName, sqlFieldNames, sqlDataTyp
   delta <- Sys.time() - startTime
   inform(paste("Bulk load to PostgreSQL took", signif(delta, 3), attr(delta, "units")))
 }
+
+bulkLoadSpark <- function(connection, sqlTableName, data) {
+  ensure_installed("AzureStor")
+  logTrace(sprintf("Inserting %d rows into table '%s' using Databricks bulk load", nrow(data), sqlTableName))
+  start <- Sys.time()
+
+  csvFileName <- tempfile("spark_insert_", fileext = ".csv")
+  write.csv(x = data, na = "", file = csvFileName, row.names = FALSE, quote = TRUE)
+  on.exit(unlink(csvFileName))
+
+  azureEndpoint <- getAzureEndpoint()
+  containers <- AzureStor::list_storage_containers(azureEndpoint)
+  targetContainer <- containers[[Sys.getenv("AZR_CONTAINER_NAME")]]
+  AzureStor::storage_upload(
+    targetContainer,
+    src = csvFileName,
+    dest = basename(csvFileName)
+  )
+
+  on.exit(
+    AzureStor::delete_storage_file(
+      targetContainer,
+      file = basename(csvFileName),
+      confirm = FALSE
+    ),
+    add = TRUE
+  )
+
+  sql <- SqlRender::loadRenderTranslateSql(
+    sqlFilename = "sparkCopy.sql",
+    packageName = "DatabaseConnector",
+    dbms = "spark",
+    sqlTableName = sqlTableName,
+    fileName = basename(csvFileName),
+    azureAccountKey = Sys.getenv("AZR_ACCOUNT_KEY"),
+    azureStorageAccount = Sys.getenv("AZR_STORAGE_ACCOUNT"),
+    azureContainerName = Sys.getenv("AZR_CONTAINER_NAME")
+  )
+
+  tryCatch(
+    {
+      DatabaseConnector::executeSql(connection = connection, sql = sql, reportOverallTime = FALSE)
+    },
+    error = function(e) {
+      abort("Error in Databricks bulk upload. Please check Databricks/Azure Storage access.")
+    }
+  )
+    }
+  )
+  delta <- Sys.time() - start
+  inform(paste("Bulk load to Databricks took", signif(delta, 3), attr(delta, "units")))
+}
+
diff --git a/R/InsertTable.R b/R/InsertTable.R
index a4ccb9f4..5ac02ca3 100644
--- a/R/InsertTable.R
+++ b/R/InsertTable.R
@@ -121,6 +121,13 @@ validateInt64Insert <- function() {
 #' "some_aws_region", "AWS_BUCKET_NAME" = "some_bucket_name", "AWS_OBJECT_KEY" = "some_object_key",
 #' "AWS_SSE_TYPE" = "server_side_encryption_type").
 #'
+#' Spark (Databricks): The MPP bulk loading relies upon the AzureStor package
+#' to stage the data in an Azure ADLS Gen2 storage container, from which it is
+#' copied into the target table. Credentials are configured directly in the
+#' system environment using the following keys: Sys.setenv("AZR_STORAGE_ACCOUNT" =
+#' "some_azure_storage_account", "AZR_ACCOUNT_KEY" = "some_secret_account_key",
+#' "AZR_CONTAINER_NAME" = "some_container_name").
+#'
 #' PDW: The MPP bulk loading relies upon the client
 #' having a Windows OS and the DWLoader exe installed, and the following permissions granted: --Grant
 #' BULK Load permissions - needed at a server level USE master; GRANT ADMINISTER BULK OPERATIONS TO
@@ -308,6 +315,8 @@ insertTable.default <- function(connection,
       bulkLoadHive(connection, sqlTableName, sqlFieldNames, data)
     } else if (dbms == "postgresql") {
       bulkLoadPostgres(connection, sqlTableName, sqlFieldNames, sqlDataTypes, data)
+    } else if (dbms == "spark") {
+      bulkLoadSpark(connection, sqlTableName, data)
     }
   } else if (useCtasHack) {
     # Inserting using CTAS hack ----------------------------------------------------------------
diff --git a/extras/TestBulkLoad.R b/extras/TestBulkLoad.R
index b0211b39..1c2c1388 100644
--- a/extras/TestBulkLoad.R
+++ b/extras/TestBulkLoad.R
@@ -114,3 +114,37 @@ all.equal(data, data2)
 
 renderTranslateExecuteSql(connection, "DROP TABLE scratch_mschuemi.insert_test;")
 disconnect(connection)
+
+
+# Spark ------------------------------------------------------------------------------
+# Assumes the Spark (Databricks) environment variables have been set
+options(sqlRenderTempEmulationSchema = Sys.getenv("DATABRICKS_SCRATCH_SCHEMA"))
+databricksConnectionString <- paste0("jdbc:databricks://", Sys.getenv("DATABRICKS_HOST"), "/default;transportMode=http;ssl=1;AuthMech=3;httpPath=", Sys.getenv("DATABRICKS_HTTP_PATH"))
+connectionDetails <- createConnectionDetails(dbms = "spark",
+                                             connectionString = databricksConnectionString,
+                                             user = "token",
+                                             password = Sys.getenv("DATABRICKS_TOKEN"))
+
+
+connection <- connect(connectionDetails)
+system.time(
+  insertTable(connection = connection,
+              tableName = "scratch.scratch_asena5.insert_test",
+              data = data,
+              dropTableIfExists = TRUE,
+              createTable = TRUE,
+              tempTable = FALSE,
+              progressBar = TRUE,
+              camelCaseToSnakeCase = TRUE,
+              bulkLoad = TRUE)
+)
+data2 <- querySql(connection, "SELECT * FROM scratch.scratch_asena5.insert_test;", snakeCaseToCamelCase = TRUE, integer64AsNumeric = FALSE)
+
+data <- data[order(data$id), ]
+data2 <- data2[order(data2$id), ]
+row.names(data) <- NULL
+row.names(data2) <- NULL
+all.equal(data, data2)
+
+renderTranslateExecuteSql(connection, "DROP TABLE scratch.scratch_asena5.insert_test;")
+disconnect(connection)
diff --git a/inst/sql/sql_server/sparkCopy.sql b/inst/sql/sql_server/sparkCopy.sql
new file mode 100644
index 00000000..e9b43853
--- /dev/null
+++ b/inst/sql/sql_server/sparkCopy.sql
@@ -0,0 +1,10 @@
+COPY INTO @sqlTableName
+FROM 'abfss://@azureContainerName@@azureStorageAccount.dfs.core.windows.net/@fileName'
+WITH (
+  CREDENTIAL (AZURE_SAS_TOKEN = '@azureAccountKey')
+)
+FILEFORMAT = CSV
+FORMAT_OPTIONS (
+  'header' = 'true',
+  'inferSchema' = 'true'
+);
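
For reference, a minimal usage sketch of the new Spark bulk load path, assuming a reachable Databricks SQL endpoint and an ADLS Gen2 container; the connection string, credentials, and table name below are illustrative placeholders, not values taken from this change:

library(DatabaseConnector)

# Azure ADLS Gen2 settings read by checkBulkLoadCredentials() and bulkLoadSpark();
# replace the placeholder values with a real storage account, account key, and container.
Sys.setenv(
  AZR_STORAGE_ACCOUNT = "mystorageaccount",
  AZR_ACCOUNT_KEY = "my-secret-account-key",
  AZR_CONTAINER_NAME = "mycontainer"
)

connectionDetails <- createConnectionDetails(
  dbms = "spark",
  connectionString = "jdbc:databricks://myworkspace.azuredatabricks.net/default;transportMode=http;ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/abc123",
  user = "token",
  password = "my-databricks-token"
)
connection <- connect(connectionDetails)

# bulkLoad = TRUE routes the insert through bulkLoadSpark(), which stages the data
# as a CSV file in the ADLS Gen2 container and loads it with COPY INTO.
insertTable(
  connection = connection,
  tableName = "my_catalog.my_schema.insert_test",
  data = data.frame(id = 1:3, value = c("a", "b", "c")),
  dropTableIfExists = TRUE,
  createTable = TRUE,
  tempTable = FALSE,
  bulkLoad = TRUE
)

disconnect(connection)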
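
And a sketch of the COPY INTO statement that sparkCopy.sql is expected to produce once SqlRender fills in the parameters; the table, account, container, file, and credential values are placeholders, and this assumes render() substitutes @azureContainerName and @azureStorageAccount independently even though they are separated only by a literal '@':

library(SqlRender)

# Same template as inst/sql/sql_server/sparkCopy.sql, inlined here for illustration
sparkCopySql <- "COPY INTO @sqlTableName
FROM 'abfss://@azureContainerName@@azureStorageAccount.dfs.core.windows.net/@fileName'
WITH (
  CREDENTIAL (AZURE_SAS_TOKEN = '@azureAccountKey')
)
FILEFORMAT = CSV
FORMAT_OPTIONS (
  'header' = 'true',
  'inferSchema' = 'true'
);"

# Render with example values to inspect the statement that would be sent to Databricks
cat(render(
  sparkCopySql,
  sqlTableName = "my_schema.insert_test",
  azureContainerName = "mycontainer",
  azureStorageAccount = "mystorageaccount",
  fileName = "spark_insert_123.csv",
  azureAccountKey = "placeholder-credential"
))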